trap/trap/tracker.py

83 lines
3.3 KiB
Python

from argparse import Namespace
import numpy as np
import torch
import zmq
import cv2
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from deep_sort_realtime.deepsort_tracker import DeepSort
# Informal description of one raw detection row: [x1, y1, x2, y2, score, label].
# NOTE(review): these are list literals containing types, not typing
# constructs, so static type checkers will not treat them as type aliases.
Detection = [int, int, int, int, float, int]
# A list of Detection rows. Tracker.detect_persons_deepsort_wrapper() turns
# such rows into DeepSort's ([left, top, width, height], score, label) tuples.
Detections = [Detection]
class Tracker:
    """Detect persons in frames received over ZeroMQ and track them with DeepSORT.

    Frames arrive on a SUB socket bound to ``config.zmq_frame_addr``. Each
    frame is run through torchvision's RetinaNet (ResNet50-FPN v2), and the
    person detections are fed to a ``deep_sort_realtime`` ``DeepSort`` tracker.
    """

    # Index of "person" in the COCO label map used by torchvision's detection
    # weights (index 0 is the background class).
    PERSON_LABEL = 1

    def __init__(self, config: Namespace):
        """Set up the frame socket, the detection model and the MOT tracker.

        :param config: parsed configuration; only ``zmq_frame_addr`` is read here.
        """
        context = zmq.Context()
        self.frame_sock = context.socket(zmq.SUB)
        # Most ZeroMQ socket options (CONFLATE included) only take effect for
        # subsequent bind/connect calls, so they must be set before bind().
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame
        # A SUB socket discards every message until it has a subscription;
        # the empty prefix subscribes to all messages.
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b"")
        self.frame_sock.bind(config.zmq_frame_addr)

        # TODO: config device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
        self.model.to(self.device)
        # Put the model in inference mode
        self.model.eval()
        # Get the transforms for the model's weights
        self.preprocess = weights.transforms().to(self.device)
        self.mot_tracker = DeepSort(max_age=5)

    def track(self):
        """Receive frames forever, detect persons in each and update the tracks."""
        while True:
            # NOTE(review): recv() returns raw bytes, but detect_persons() and
            # DeepSort expect a decoded HxWx3 BGR numpy image. How to decode
            # depends on how the publisher serialises frames. Confirm this and
            # decode here.
            frame = self.frame_sock.recv()
            detections = self.detect_persons(frame)
            tracks = self.mot_tracker.update_tracks(detections, frame=frame)
            # TODO: provide a track object that actually keeps history (unlike tracker)

    def detect_persons(self, frame) -> list:
        """Run the detector on one frame and return its person detections.

        :param frame: presumably an HxWx3 BGR uint8 numpy image, as produced by
            OpenCV. TODO: confirm against the frame publisher.
        :return: list of ``([left, top, width, height], score, label)`` tuples,
            the raw-detection format used by ``DeepSort.update_tracks``.
        """
        t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # OpenCV images are laid out H,W,C; torchvision (like
        # torchvision.io.read_image) expects C,H,W.
        t = t.permute(2, 0, 1)
        batch = self.preprocess(t)[None, :].to(self.device)
        # no_grad can be used on inference, should be slightly faster
        with torch.no_grad():
            predictions = self.model(batch)
        prediction = predictions[0]  # we feed only one frame at once
        return self._person_detections(prediction)

    @classmethod
    def _person_detections(cls, prediction) -> list:
        """Keep only person detections from one model output, in DeepSort format.

        :param prediction: one element of the model output. It is a dict with
            ``boxes`` (N x 4, x1 y1 x2 y2), ``scores`` (N) and ``labels`` (N)
            tensors.
        :return: see ``detect_persons``.
        """
        # TODO: check if we need e.g. cyclist
        # For more than one label: torch.isin(prediction['labels'], torch.tensor([...], device=...))
        mask = prediction['labels'] == cls.PERSON_LABEL
        boxes = prediction['boxes'][mask]
        scores = prediction['scores'][mask]
        labels = prediction['labels'][mask]
        # TODO: introduce confidence and NMS supression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
        # (which I _think_ we better do after filtering)
        # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c
        # Note that torchvision's RetinaNet already applies score_thresh and
        # per-class NMS in its own postprocessing.
        # Rows in the format [x1, y1, x2, y2, score, label]. tolist() keeps the
        # label an int instead of coercing it to float.
        rows = [
            [*box, score, label]
            for box, score, label in zip(boxes.tolist(), scores.tolist(), labels.tolist())
        ]
        return cls.detect_persons_deepsort_wrapper(rows)

    @staticmethod
    def detect_persons_deepsort_wrapper(detections):
        """Convert ``[x1, y1, x2, y2, score, label]`` rows to DeepSort's format.

        deep_sort_realtime expects ``([left, top, width, height], confidence,
        class)`` tuples, so each corner-based (ltrb) box is converted to ltwh.
        This is a staticmethod: the former ``@classmethod`` lacked a ``cls``
        parameter, so every call raised TypeError.
        """
        return [([d[0], d[1], d[2] - d[0], d[3] - d[1]], d[4], d[5]) for d in detections]
def run_tracker(config: Namespace):
    """Construct a Tracker from *config* and enter its tracking loop."""
    Tracker(config).track()