trajpred/02_track_objects.ipynb


In [1]:
import cv2
from pathlib import Path
import numpy as np
# from PIL import Image
import torch
from torchvision.io.video import read_video
import matplotlib.pyplot as plt
from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
In [2]:
source = Path('../DATASETS/VIRAT_subset_0102x')
videos = source.glob('*.mp4')
homography = list(source.glob('*img2world.txt'))[0]
H = np.loadtxt(homography, delimiter=',')

The homography matrix transforms points from image space to a flat world plane. The README_homography.txt from VIRAT describes it as follows:

Roughly estimated 3-by-3 homographies are included for convenience. Each homography H provides a mapping from image coordinate to scene-dependent world coordinate.
[xw,yw,zw]' = H*[xi,yi,1]'

xi: horizontal axis on image with left top corner as origin, increases right.
yi: vertical axis on image with left top corner as origin, increases downward.

xw/zw: world x coordinate
yw/zw: world y coordinate
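
As a quick check, a single image point can be projected onto the world plane by multiplying with H and normalizing by the homogeneous coordinate zw. A minimal sketch (the pixel coordinates below are arbitrary):

p_img = np.array([20, 300, 1])               # pixel (xi, yi) in homogeneous coordinates
p_world = H @ p_img                          # [xw, yw, zw]
x_world, y_world = p_world[:2] / p_world[2]  # divide by zw to get world x, y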

In [3]:
# H.dot(np.array([20,300, 1]))
In [4]:
video_path = list(videos)[0]
# override with a fixed clip for reproducibility
video_path = Path("../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4")
In [5]:
video_path
Out[5]:
PosixPath('../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4')
In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
Out[34]:
device(type='cuda')
In [37]:
weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
model.to(device)
# Put the model in inference mode
model.eval()
# Get the transforms for the model's weights
preprocess = weights.transforms().to(device)
In [38]:
# hub.set_dir()
In [39]:
video = cv2.VideoCapture(str(video_path))

The score_thresh argument sets the confidence threshold for detections: an object is only assigned to a class if the model is at least 35% confident that it belongs to that class.
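
The same cut could also be applied by hand on a prediction dict; a minimal sketch (here prediction stands for the per-frame result produced in the loop below):

keep = prediction['scores'] >= 0.35   # confidence threshold, matching score_thresh above
boxes = prediction['boxes'][keep]
labels = prediction['labels'][keep]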

The result of a single prediction from model(batch) looks like:

{'boxes': tensor([[5.7001e+02, 2.5786e+02, 6.3138e+02, 3.6970e+02],
         [5.0109e+02, 2.4508e+02, 5.5308e+02, 3.4852e+02],
         [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],
         [5.0219e-01, 3.7588e+02, 9.7911e+01, 7.2000e+02],
         [3.4096e+02, 2.7015e+02, 3.6156e+02, 3.1857e+02],
         [8.3241e+01, 5.8410e+02, 1.7502e+02, 7.1743e+02]]),
 'scores': tensor([0.8525, 0.6491, 0.5985, 0.4999, 0.3753, 0.3746]),
 'labels': tensor([64, 64,  1, 64, 18, 86])}
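
The labels are numeric COCO category indices; they can be mapped to readable names through the weights metadata, e.g.:

# map numeric labels to category names (label 1 is 'person')
category_names = [weights.meta["categories"][int(i)] for i in prediction['labels']]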

Now with SORT tracking

Using a SORT implementation originally by Alex Bewley, adapted by Chris Fotache. For an example implementation, see his notebook.

In [56]:
from sort_cfotache import Sort

mot_tracker = Sort()

display_image = True
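
Sort.update expects one frame's detections as a numpy array of [[x1,y1,x2,y2,score],...] rows and returns the currently active tracks, where each row starts with the box coordinates followed by the track id. A minimal sketch with made-up boxes (the exact trailing columns may differ in this adaptation):

dets = np.array([[570., 257., 631., 369., 0.85],
                 [501., 245., 553., 348., 0.65]])
tracks = mot_tracker.update(dets)   # rows like [x1, y1, x2, y2, track_id, ...]
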
In [57]:
tracked_instances = {}
In [58]:
# TODO make into loop
%matplotlib inline


import pylab as pl
from IPython import display
from utils.timer import Timer

i=0
timer = Timer()
while True:
    timer.tic()
    ret, frame = video.read()
    i+=1
    
    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break

    # convert the BGR OpenCV frame to an RGB tensor in CHW order
    t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    t = t.permute(2, 0, 1)

    batch = preprocess(t)[None, :].to(device)
    # no_grad can be used on inference, should be slightly faster
    with torch.no_grad():
        predictions = model(batch)
    prediction = predictions[0] # we feed only one frame at a time

    mask = prediction['labels'] == 1 # if we want more than one: np.isin(prediction['labels'], [1,86])

    scores = prediction['scores'][mask]
    labels = prediction['labels'][mask]
    boxes = prediction['boxes'][mask]
    
    # TODO: introduce confidence thresholding and NMS suppression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
    # (which I _think_ we better do after filtering)
    # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c

    
    #  dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
    detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())])
    # print(detections)
    tracks = mot_tracker.update(detections)

    # now convert back to boxes and labels
    # print(tracks)
    boxes = np.array([t[:4] for t in tracks])
    # initialize an empty tensor with the necessary dimensions to work around a draw_bounding_boxes glitch
    t_boxes = torch.from_numpy(boxes) if len(boxes) else torch.Tensor().new_empty([0, 6])
    labels = [str(int(t[4])) for t in tracks]
    # print(t_boxes, boxes, labels)


    for track in tracks:
        # TODO add to tracked_instances
        track_id = str(int(track[4]))
        if track_id not in tracked_instances:
            tracked_instances[track_id] = []
        tracked_instances[track_id].append(track)

    
    # labels = [weights.meta["categories"][i] for i in labels]

    if display_image:
        box = draw_bounding_boxes(t, boxes=t_boxes,
                                labels=labels,
                                colors="cyan",
                                width=2, 
                                font_size=30,
                                # font='Arial'
                                )

        im = to_pil_image(box.detach())

        display.display(im, f"frame {i}")
    print(prediction)
    print("time for frame: ", timer.toc(), ", avg:", 1/timer.average_time, "fps")

    display.clear_output(wait=True)

    # break # for now
    # pl.clf()
    # # pl.plot(pl.randn(100))
    # pl.figure(figsize=(24,50))
    # # fig.axes[0].imshow(img)
    # pl.imshow(im)
    # display.display(pl.gcf(), f"frame {i}")
    # display.clear_output(wait=True)
    # time.sleep(1.0)

    # fig, ax = plt.subplots(figsize=(16, 12))
    # ax.imshow(im)
    # plt.show()
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[58], line 29
     27 # no_grad can be used on inference, should be slightly faster
     28 with torch.no_grad():
---> 29     predictions = model(batch)
     30 prediction = predictions[0] # we feed only one frame at the once
     32 mask = prediction['labels'] == 1 # if we want more than one: np.isin(prediction['labels'], [1,86])

File ~/suspicion/trajpred/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/suspicion/trajpred/.venv/lib/python3.9/site-packages/torchvision/models/detection/retinanet.py:663, in RetinaNet.forward(self, images, targets)
    660     split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
    662     # compute the detections
--> 663     detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
    664     detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
    666 if torch.jit.is_scripting():

File ~/suspicion/trajpred/.venv/lib/python3.9/site-packages/torchvision/models/detection/retinanet.py:531, in RetinaNet.postprocess_detections(self, head_outputs, anchors, image_shapes)
    529 scores_per_level = torch.sigmoid(logits_per_level).flatten()
    530 keep_idxs = scores_per_level > self.score_thresh
--> 531 scores_per_level = scores_per_level[keep_idxs]
    532 topk_idxs = torch.where(keep_idxs)[0]
    534 # keep only topk scoring predictions

KeyboardInterrupt: 
In [55]:
 
Out[55]:
dict_keys(['22', '24', '26', '27', '30', '31', '32', '33', '37'])
In [ ]:
 
Out[ ]:
(array([[5.30405334e+02, 5.34641296e+02, 6.03237061e+02, 7.18612122e+02,
         9.42070127e-01, 1.00000000e+00],
        [4.61479340e+02, 5.49811340e+02, 5.34607056e+02, 7.17237122e+02,
         9.26090062e-01, 1.00000000e+00],
        [3.38673218e+02, 2.55078461e+02, 3.57062561e+02, 2.95217896e+02,
         6.61470771e-01, 1.00000000e+00]]),)
In [24]:
 
Out[24]:
{'17': [array([573.00909697, 551.76122438, 657.56378982, 720.05069192,
          17.        ,   1.        ]),
  array([570.16715738, 550.85464258, 652.59986304, 719.88004284,
          17.        ,   1.        ]),
  array([568.02909891, 550.10706805, 649.96206622, 720.03113806,
          17.        ,   1.        ]),
  array([562.49451695, 549.06638446, 644.29895964, 720.04103925,
          17.        ,   1.        ])],
 '13': [array([337.63475088, 255.66774475, 355.97561492, 296.69147428,
          13.        ,   1.        ]),
  array([337.77042983, 255.72223676, 356.05113319, 296.63698388,
          13.        ,   1.        ]),
  array([338.02427059, 255.89595935, 356.25536645, 296.58306741,
          13.        ,   1.        ]),
  array([338.1632419 , 255.82719651, 356.27227032, 296.33234513,
          13.        ,   1.        ])],
 '12': [array([481.57704931, 568.79192296, 570.79284909, 718.23349465,
          12.        ,   1.        ]),
  array([479.96268827, 569.31456975, 567.89464999, 718.91657277,
          12.        ,   1.        ]),
  array([478.23383288, 568.93539717, 565.05653529, 718.92571522,
          12.        ,   1.        ]),
  array([475.43950486, 567.4295262 , 561.46362594, 718.3620136 ,
          12.        ,   1.        ])]}
In [25]:
 
Out[25]:
True
In [ ]: