from argparse import Namespace
from typing import List, Tuple

import cv2
import numpy as np
import torch
import zmq
from deep_sort_realtime.deepsort_tracker import DeepSort
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
|
|
|
|
# A single raw detection row: (x1, y1, x2, y2, score, class_label).
# The original spelled this as a plain list of type objects
# ([int, int, int, int, float, int]), which is not a usable type alias;
# typing.Tuple/List express the same intent correctly.
# NOTE(review): this describes the format assembled inside detect_persons();
# detect_persons_deepsort_wrapper() re-nests it as ([l, t, w, h], score, label)
# before it is returned — confirm which shape callers actually expect.
Detection = Tuple[int, int, int, int, float, int]
Detections = List[Detection]
|
|
|
|
class Tracker:
    """Receive video frames over ZeroMQ, detect persons with RetinaNet,
    and feed the detections into a DeepSort multi-object tracker.

    The ``track()`` loop blocks forever; construct one Tracker per process.
    """

    def __init__(self, config: Namespace):
        """Set up the frame subscription socket, the detection model and the
        DeepSort tracker.

        :param config: parsed CLI/config namespace; must provide
            ``zmq_frame_addr`` (the address this SUB socket binds to).
        """
        context = zmq.Context()
        self.frame_sock = context.socket(zmq.SUB)
        # Socket options must be applied before bind/connect to take effect.
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame
        # A SUB socket delivers nothing until it subscribes; subscribe to all
        # topics (the original omitted this, so recv() would block forever).
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b"")
        self.frame_sock.bind(config.zmq_frame_addr)

        # TODO: config device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
        # score_thresh filters low-confidence boxes inside the model itself.
        self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
        self.model.to(self.device)
        # Put the model in inference mode
        self.model.eval()
        # Get the transforms for the model's weights
        self.preprocess = weights.transforms().to(self.device)

        # max_age: frames a lost track survives before deletion.
        self.mot_tracker = DeepSort(max_age=5)

    def track(self):
        """Blocking loop: receive a frame, detect persons, update DeepSort."""
        while True:
            # NOTE(review): recv() returns raw bytes, but detect_persons()
            # treats ``frame`` as a numpy image (cv2.cvtColor) — confirm how
            # the sender serializes frames and decode here if needed.
            frame = self.frame_sock.recv()
            detections = self.detect_persons(frame)
            tracks = self.mot_tracker.update_tracks(detections, frame=frame)

            # TODO: provide a track object that actually keeps history (unlike tracker)

    def detect_persons(self, frame) -> "Detections":
        """Run the RetinaNet model on a BGR image and return person
        detections in the nested format expected by deep_sort_realtime.

        :param frame: BGR image as a numpy array (H, W, C).
        :return: list of ([left, top, width, height], confidence, label).
        """
        t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # change axes of image loaded image to be compatible with
        # torch.io.read_image (which has C,W,H format instead of W,H,C)
        t = t.permute(2, 0, 1)

        # [None, :] adds the batch dimension the model expects.
        batch = self.preprocess(t)[None, :].to(self.device)
        # no_grad can be used on inference, should be slightly faster
        with torch.no_grad():
            predictions = self.model(batch)
        prediction = predictions[0]  # we feed only one frame at once

        # TODO: check if we need e.g. cyclist
        # COCO label 1 == person.
        mask = prediction['labels'] == 1  # if we want more than one label: np.isin(prediction['labels'], [1,86])

        scores = prediction['scores'][mask]
        labels = prediction['labels'][mask]
        boxes = prediction['boxes'][mask]

        # TODO: introduce confidence and NMS supression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
        # (which I _think_ we better do after filtering)
        # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c

        # dets - a numpy array of detections in the format [[x1,y1,x2,y2,score, label],[x1,y1,x2,y2,score, label],...]
        detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())])
        detections = self.detect_persons_deepsort_wrapper(detections)

        return detections

    @staticmethod
    def detect_persons_deepsort_wrapper(detections):
        """make detect_persons() compatible with
        deep_sort_realtime tracker by going from ltrb to ltwh and
        different nesting

        Fix: the original declared this a @classmethod but omitted the ``cls``
        parameter, so ``cls`` consumed the ``detections`` argument and every
        call raised TypeError; @staticmethod matches the intended signature.
        """
        return [([d[0], d[1], d[2] - d[0], d[3] - d[1]], d[4], d[5]) for d in detections]
|
|
|
|
|
|
def run_tracker(config: Namespace):
    """Build a Tracker from *config* and enter its blocking track loop."""
    tracker = Tracker(config)
    tracker.track()