trajpred/03_track_objects_and_collec...

Use SORT tracking over a video collection and project results

This notebook uses a SORT implementation originally by Alex Bewley, adapted by Chris Fotache. For an example implementation, see his notebook.
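The tracker interface is simple; below is a minimal sketch of how it is driven, assuming the Sort class from sort_cfotache that is imported later in this notebook:

import numpy as np
from sort_cfotache import Sort

mot_tracker = Sort()

# one frame's detections: rows of [x1, y1, x2, y2, score]
detections = np.array([
    [570.0, 257.9, 631.4, 369.7, 0.85],
    [501.1, 245.1, 553.1, 348.5, 0.65],
])

# update() is called once per frame and returns rows of [x1, y1, x2, y2, track_id];
# IDs stay stable across frames as long as the tracker can associate the boxes.
tracks = mot_tracker.update(detections)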

In [1]:
import cv2
from pathlib import Path
import numpy as np
from PIL import Image
import torch
from torchvision.io.video import read_video
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
import tempfile        
In [2]:
source = Path('../DATASETS/VIRAT_subset_0102x')
videos = list(source.glob('*.mp4'))
tmpdir = Path(tempfile.gettempdir()) / 'trajpred'
tmpdir.mkdir(exist_ok=True)
In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Out[3]:
device(type='cuda')
In [4]:
weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
model.to(device)
# Put the model in inference mode
model.eval()
# Get the transforms for the model's weights
preprocess = weights.transforms().to(device)

The score_thresh argument sets the detection threshold for each class. Intuitively it is a confidence threshold: an object is only reported as belonging to a class if the model is at least 35% confident that it does.

The result from a single prediction coming from model(batch) looks like:

{'boxes': tensor([[5.7001e+02, 2.5786e+02, 6.3138e+02, 3.6970e+02],
         [5.0109e+02, 2.4508e+02, 5.5308e+02, 3.4852e+02],
         [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],
         [5.0219e-01, 3.7588e+02, 9.7911e+01, 7.2000e+02],
         [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],
         [8.3241e+01, 5.8410e+02, 1.7502e+02, 7.1743e+02]]),
 'scores': tensor([0.8525, 0.6491, 0.5985, 0.4999, 0.3753, 0.3746]),
 'labels': tensor([64, 64,  1, 64, 18, 86])}
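The integer labels index into the COCO category list that ships with the weights' metadata. A small sketch (using the weights object defined above) to make such labels readable:

categories = weights.meta["categories"]  # list of class names, indexed by label id
print(categories[1])  # 'person' -- the class that will be kept for tracking below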
In [5]:
%matplotlib inline


import pylab as pl
from IPython import display
from utils.timer import Timer
from sort_cfotache import Sort
import pickle


def track_video(video_path: Path):
    # generator: yields one SORT track row per tracked object per frame
    mot_tracker = Sort()

    video = cv2.VideoCapture(str(video_path))

    # timer = Timer()
    while True:
        # timer.tic()
        ret, frame = video.read()
        
        if not ret:
            # print("Can't receive frame (stream end?). Exiting ...")
            break

        t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # change axes of the loaded image to be compatible with torchvision.io.read_image (which uses C,H,W format instead of the H,W,C that OpenCV gives us)
        t = t.permute(2, 0, 1)

        batch = preprocess(t)[None, :].to(device)
        # no_grad can be used on inference, should be slightly faster
        with torch.no_grad():
            predictions = model(batch)
        prediction = predictions[0] # we feed only one frame at a time

        mask = prediction['labels'] == 1 # keep only 'person' detections (COCO label 1); for more classes: np.isin(prediction['labels'], [1,86])

        scores = prediction['scores'][mask]
        labels = prediction['labels'][mask]
        boxes = prediction['boxes'][mask]
        
        # TODO: introduce confidence and NMS suppression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
        # (which I _think_ we better do after filtering)
        # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c

        
        #  dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
        detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())])
        # print(detections)
        tracks = mot_tracker.update(detections)

        # now convert back to boxes and labels
        # print(tracks)
        boxes = np.array([t[:4] for t in tracks])
        # initialize an empty tensor with the dimensions draw_bounding_boxes needs, to work around a glitch when there are no boxes
        t_boxes = torch.from_numpy(boxes) if len(boxes) else torch.Tensor().new_empty([0, 6])
        labels = [str(int(t[4])) for t in tracks]
        # print(t_boxes, boxes, labels)



        for track in tracks:
            yield track
            

    #     print("time for frame: ", timer.toc(), ", avg:", 1/timer.average_time, "fps")

    #     display.clear_output(wait=True)

    # return tracked_instances
/home/ruben/suspicion/trajpred/sort_cfotache.py:36: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  def iou(bb_test,bb_gt):
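The NMS step from the TODO above could be handled with torchvision's built-in operator. A minimal sketch (not applied in this notebook), assuming the filtered boxes, scores and labels tensors from inside the loop:

from torchvision.ops import nms

# keep only the highest-scoring box among heavily overlapping ones (IoU > 0.5)
keep = nms(boxes, scores, iou_threshold=0.5)
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]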
In [6]:
def track_videos(video_paths: list[Path]) -> dict:
    # collect instances of all videos with unique key
    video_paths = list(video_paths)
    tracked_instances = {}
    timer = Timer()
    for i, p in enumerate(video_paths):
        print(f"{i}/{len(video_paths)}: {p}")

        cachefile = tmpdir / (p.name + '.pcl')
        if cachefile.exists():
            print('\tLoad pickle')
            with cachefile.open('rb') as fp:
                new_instances = pickle.load(fp)
        else:
            #continue # to quickly test from cache
            new_instances = {}
            timer.tic()
            for track in track_video(p):
                track_id = f"{i}_{str(int(track[4]))}"
                if track_id not in new_instances:
                    new_instances[track_id] = []
                new_instances[track_id].append(track)
            with cachefile.open('wb') as fp:
                pickle.dump(new_instances, fp)
            print(" time for video: ", timer.toc())
        tracked_instances.update(new_instances)
        
    return tracked_instances
In [7]:
tracked_instances = track_videos(videos)
len(tracked_instances)
0/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4
	Load pickle
1/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_09_001285_001336.mp4
	Load pickle
2/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_08_000895_000975.mp4
	Load pickle
3/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010205_04_000545_000576.mp4
	Load pickle
4/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_04_000929_000954.mp4
	Load pickle
5/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_10_000923_000959.mp4
	Load pickle
6/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010205_06_000830_000904.mp4
	Load pickle
7/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_08_001308_001332.mp4
	Load pickle
8/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_09_001484_001510.mp4
	Load pickle
9/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_00_000047_000139.mp4
	Load pickle
10/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010205_03_000370_000395.mp4
	Load pickle
11/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_02_000414_000439.mp4
	Load pickle
12/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_03_000865_000911.mp4
	Load pickle
13/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_09_000857_000886.mp4
	Load pickle
14/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_09_001010_001036.mp4
	Load pickle
15/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_00_000000_000053.mp4
	Load pickle
16/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_05_000499_000527.mp4
	Load pickle
17/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_03_000400_000435.mp4
	Load pickle
18/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_08_000705_000739.mp4
	Load pickle
19/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_01_000712_000752.mp4
	Load pickle
20/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_06_000671_000744.mp4
	Load pickle
21/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_05_000856_000890.mp4
	Load pickle
22/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_06_000620_000760.mp4
	Load pickle
23/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_04_000374_000469.mp4
	Load pickle
24/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_03_000270_000359.mp4
 time for video:  76.47440218925476
25/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_04_000646_000754.mp4
 time for video:  84.25160992145538
26/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010202_00_000001_000033.mp4
 time for video:  62.6530507405599
27/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_08_000838_000867.mp4
 time for video:  51.79480332136154
28/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_03_000606_000632.mp4
 time for video:  44.33411946296692
29/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010205_02_000301_000345.mp4
 time for video:  43.13727605342865
30/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_05_000515_000593.mp4
 time for video:  45.87533599989755
31/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010205_01_000207_000288.mp4
 time for video:  48.75653102993965
32/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_07_000942_000989.mp4
 time for video:  47.200045612123276
33/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_02_000167_000197.mp4
 time for video:  44.192170357704164
34/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_05_001013_001038.mp4
 time for video:  41.411013711582534
35/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010202_06_000784_000873.mp4
 time for video:  43.82025545835495
36/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_02_000347_000397.mp4
 time for video:  43.30084228515625
37/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_01_000072_000225.mp4
 time for video:  49.74186216081892
38/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_02_000349_000398.mp4
 time for video:  48.90173446337382
39/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_01_000124_000206.mp4
 time for video:  50.23752883076668
40/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010202_02_000161_000189.mp4
 time for video:  48.407087087631226
41/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_03_000546_000580.mp4
 time for video:  47.02203807565901
42/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_06_000702_000744.mp4
 time for video:  46.14274973618357
43/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_00_000000_000049.mp4
 time for video:  45.96134517192841
44/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_01_000125_000152.mp4
 time for video:  44.89160401480539
45/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_04_000568_000620.mp4
 time for video:  44.69535830887881
46/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_07_000748_000837.mp4
 time for video:  46.03415718285934
47/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_07_001195_001260.mp4
 time for video:  46.31604018807411
48/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_07_000601_000697.mp4
 time for video:  47.671081705093385
49/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_09_000886_000915.mp4
 time for video:  46.78188127737779
50/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_10_001372_001395.mp4
 time for video:  45.63375103032148
51/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_05_000658_000700.mp4
 time for video:  45.015538905348095
52/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010202_03_000313_000355.mp4
 time for video:  44.412532288452674
53/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_02_000790_000816.mp4
 time for video:  43.610562173525494
54/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_04_000457_000511.mp4
 time for video:  43.67899775505066
55/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010202_01_000055_000147.mp4
 time for video:  44.6242256090045
56/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_11_001524_001607.mp4
 time for video:  44.944525480270386
57/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_01_000254_000322.mp4
 time for video:  45.192202406771045
58/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_05_000591_000631.mp4
 time for video:  44.800596714019775
59/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_10_000904_000991.mp4
 time for video:  45.58240665329827
60/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_06_001064_001097.mp4
 time for video:  44.88245353827605
61/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_02_000150_000180.mp4
 time for video:  44.24212918156072
62/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_04_000720_000767.mp4
 time for video:  43.92639535512679
63/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_08_000807_000831.mp4
 time for video:  43.20958806276322
64/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_10_001092_001121.mp4
 time for video:  42.71179392279648
65/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_05_000797_000823.mp4
 time for video:  42.104674679892405
66/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_06_000550_000600.mp4
 time for video:  42.07939862650494
67/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010201_09_000770_000801.mp4
 time for video:  41.708983686837286
68/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010206_00_000007_000035.mp4
 time for video:  41.25493524339464
69/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_00_000030_000059.mp4
 time for video:  40.88401547203893
70/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_03_000201_000232.mp4
 time for video:  40.45448579686753
71/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010204_06_000913_000939.mp4
 time for video:  39.99556113779545
72/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010207_10_001549_001596.mp4
 time for video:  39.92924086415038
73/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010203_07_000775_000869.mp4
 time for video:  40.64969404220581
74/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_03_000470_000567.mp4
 time for video:  41.35395318854089
75/76: ../DATASETS/VIRAT_subset_0102x/VIRAT_S_010208_07_000768_000791.mp4
 time for video:  40.88396269083023
Out[7]:
5952

Project / Homography

Now that all trajectories are captured for the video collection, they can be projected onto a flat ground plane by homography. The necessary $H$ matrix is already provided by VIRAT in the homographies folder of their online data repository.

In [8]:
homography = list(source.glob('*img2world.txt'))[0]
H = np.loadtxt(homography, delimiter=',')

The homography matrix maps points from image space to a flat world plane. The README_homography.txt from VIRAT describes it as follows:

Roughly estimated 3-by-3 homographies are included for convenience. Each homography H provides a mapping from image coordinate to scene-dependent world coordinate.
[xw,yw,zw]' = H*[xi,yi,1]'

xi: horizontal axis on image with left top corner as origin, increases right.
yi: vertical axis on image with left top corner as origin, increases downward.

xw/zw: world x coordinate
yw/zw: world y coordinate
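A minimal sketch of that mapping, projecting a single image point by hand (cv2.perspectiveTransform, used below, performs the same homogeneous division internally):

xi, yi = 570.0, 370.0                     # an arbitrary image point, in pixels
xw, yw, zw = H @ np.array([xi, yi, 1.0])  # homogeneous world coordinates
world_point = (xw / zw, yw / zw)          # divide by zw to land on the flat world plane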

In [9]:
print(Image.open("../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.png").size)
Image.open("../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.png")
(1200, 900)
Out[9]:
In [10]:
from matplotlib import pyplot as plt

fig = plt.figure(figsize=(20,8))
ax1, ax2 = fig.subplots(1,2)

ax1.set_aspect(1)
ax2.imshow(Image.open("../DATASETS/VIRAT_subset_0102x/VIRAT_S_0102.jpg"))

for track_id in tracked_instances:
    # print(track_id)
    bboxes = tracked_instances[track_id]
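    # use the bottom-centre of each bounding box as the foot point of the tracked object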
    traj = np.array([[[0.5 * (det[0]+det[2]), det[3]]] for det in bboxes])
    projected_traj = cv2.perspectiveTransform(traj,H)
    # plt.plot(projected_traj[:,0])
    ax1.plot(projected_traj[:,:,0].reshape(-1), projected_traj[:,:,1].reshape(-1))
    ax2.plot(traj[:,:,0].reshape(-1), traj[:,:,1].reshape(-1))
    
plt.show()

What if the projection were a heatmap of where people are? (Tracking would not really be necessary for that, though.) Below, the plot above is reused with some blurring effects taken from the matplotlib agg-filter documentation.

In [37]:
from matplotlib import gridspec
import matplotlib.cm as cm
import matplotlib.transforms as mtransforms
from matplotlib.colors import LightSource
from matplotlib.artist import Artist


def smooth1d(x, window_len):
    # copied from https://scipy-cookbook.readthedocs.io/items/SignalSmooth.html
    s = np.r_[2*x[0] - x[window_len:1:-1], x, 2*x[-1] - x[-1:-window_len:-1]]
    w = np.hanning(window_len)
    y = np.convolve(w/w.sum(), s, mode='same')
    return y[window_len-1:-window_len+1]


def smooth2d(A, sigma=3):
    window_len = max(int(sigma), 3) * 2 + 1
    A = np.apply_along_axis(smooth1d, 0, A, window_len)
    A = np.apply_along_axis(smooth1d, 1, A, window_len)
    return A


class BaseFilter:

    def get_pad(self, dpi):
        return 0

    def process_image(self, padded_src, dpi):
        raise NotImplementedError("Should be overridden by subclasses")

    def __call__(self, im, dpi):
        pad = self.get_pad(dpi)
        padded_src = np.pad(im, [(pad, pad), (pad, pad), (0, 0)], "constant")
        tgt_image = self.process_image(padded_src, dpi)
        return tgt_image, -pad, -pad



class GaussianFilter(BaseFilter):
    """Simple Gaussian filter."""

    def __init__(self, sigma, alpha=0.5, color=(0, 0, 0)):
        self.sigma = sigma
        self.alpha = alpha
        self.color = color

    def get_pad(self, dpi):
        return int(self.sigma*3 / 72 * dpi)

    def process_image(self, padded_src, dpi):
        tgt_image = np.empty_like(padded_src)
        tgt_image[:, :, :3] = self.color
        tgt_image[:, :, 3] = smooth2d(padded_src[:, :, 3] * self.alpha,
                                      self.sigma / 72 * dpi)
        return tgt_image

gauss = GaussianFilter(2)

fig = plt.figure(figsize=(20,12))


# Create 2x2 sub plots
gs = gridspec.GridSpec(2, 2)

# (ax1, ax2), (ax3, ax4) = fig.subplots(2,2)
ax1 = fig.add_subplot(gs[0,0])
ax3 = fig.add_subplot(gs[1,0])
ax2 = fig.add_subplot(gs[:,1])

ax1.set_aspect(1)
ax3.set_aspect(1)

ax2.imshow(Image.open("../DATASETS/VIRAT_subset_0102x/VIRAT_S_0102.jpg"))

for track_id in tracked_instances:
    # print(track_id)
    bboxes = tracked_instances[track_id]
    traj = np.array([[[0.5 * (det[0]+det[2]), det[3]]] for det in bboxes])
    projected_traj = cv2.perspectiveTransform(traj,H)
    # plt.plot(projected_traj[:,0])
    line, = ax1.plot(projected_traj[:,:,0].reshape(-1), projected_traj[:,:,1].reshape(-1), color=(0,0,0,0.05))
    line.set_agg_filter(gauss)
    line.set_rasterized(True) # "to support mixed-mode renderers"

    points = ax3.scatter(projected_traj[:,:,0].reshape(-1), projected_traj[:,:,1].reshape(-1), color=(0,0,0,0.01))
    points.set_agg_filter(gauss)
    points.set_rasterized(True) # "to support mixed-mode renderers"

    ax2.plot(traj[:,:,0].reshape(-1), traj[:,:,1].reshape(-1))
    
plt.show()
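
An alternative heatmap (not used above, and assuming SciPy is available) skips the agg filters altogether: accumulate all projected foot points into a 2D histogram and blur the counts:

from scipy.ndimage import gaussian_filter

# collect every projected foot point of every track into one (N, 2) array
points = np.concatenate([
    cv2.perspectiveTransform(
        np.array([[[0.5 * (det[0] + det[2]), det[3]]] for det in bboxes]), H
    ).reshape(-1, 2)
    for bboxes in tracked_instances.values()
])

# bin the points on the ground plane and smooth the result
heat, xedges, yedges = np.histogram2d(points[:, 0], points[:, 1], bins=200)
plt.imshow(gaussian_filter(heat.T, sigma=2), origin='lower',
           extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]])
plt.show()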