@dataclass
class Detection:
    """A single bounding box belonging to a tracked object in one frame.

    The box is stored as left/top/width/height, matching the order in which
    ``from_deepsort`` unpacks ``to_ltwh()``.
    """

    # Identifier of the track this box belongs to. Annotated as ``str``, but
    # callers convert with ``str(...)`` themselves before using it as a key.
    track_id: str
    l: int  # left
    t: int  # top
    w: int  # width
    h: int  # height

    def get_foot_coords(self):
        """Return ``[x, y]`` of the bottom-centre point of the box.

        This is the horizontal middle of the box (``l + w / 2``) at its lower
        edge (``t + h``).
        """
        return [self.l + 0.5 * self.w, self.t + self.h]

    @classmethod
    def from_deepsort(cls, dstrack: "DeepsortTrack"):
        """Build a Detection from a DeepSort track.

        Only ``dstrack.track_id`` and ``dstrack.to_ltwh()`` are used; the four
        values returned by ``to_ltwh()`` are taken as l, t, w, h in that order.
        The annotation is a string so that defining this class does not
        require the DeepSort import to be resolvable.
        """
        return cls(dstrack.track_id, *dstrack.to_ltwh())


@dataclass
class Track:
    """Accumulated history of detections that share one track id.

    Declared after ``Detection`` and annotated with strings: dataclass field
    annotations are evaluated while the class body executes, so referring to
    ``Detection`` before its definition raised ``NameError`` at import time.
    """

    # None until the first detection for this track has been stored.
    track_id: "str | None" = None
    # Detections in the order they were appended; a fresh list per instance.
    history: "list[Detection]" = field(default_factory=list)
See --help") # homography = list(source.glob('*img2world.txt'))[0] self.H = np.loadtxt(self.config.homography, delimiter=',') - self.mot_tracker = DeepSort(max_age=5) + self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9) logger.debug("Set up tracker") @@ -94,30 +136,46 @@ class Tracker: # logger.info(f"Frame delivery delay = {time.time()-frame.time}s") start_time = time.time() - detections = self.detect_persons(frame.img) - tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img) - TEMP_boxes = [t.to_ltwh() for t in tracks] - TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes]) - if len(TEMP_coords): - TEMP_proj_coords = cv2.perspectiveTransform(TEMP_coords,self.H) + if self.config.detector == DETECTOR_YOLOv8: + detections: [Detection] = self._yolov8_track(frame.img) + else : + detections: [Detection] = self._resnet_track(frame.img) + + + # Store detections into tracklets + for detection in detections: + track = self.tracks[detection.track_id] + track.track_id = detection.track_id # for new tracks + + track.history.append(detection) + # if len(track.history) > 30: # retain 90 tracks for 90 frames + # track.history.pop(0) + + foot_coordinates = np.array([[t.get_foot_coords()] for t in detections]) + if len(foot_coordinates): + projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H) else: - TEMP_proj_coords = [] + projected_coordinates = [] # print(TEMP_proj_coords) trajectories = {} - for i, coords in enumerate(TEMP_proj_coords): - tid = tracks[i].track_id + for detection, coords in zip(detections, projected_coordinates): + tid = str(detection.track_id) trajectories[tid] = { "id": tid, - "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test + "history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test } - # logger.debug(f"{trajectories}") + # 
def _yolov8_track(self, img) -> "list[Detection]":
    """Detect and track objects in ``img`` with the Ultralytics model.

    Calls ``self.model.track(img, persist=True)`` and converts the boxes of
    the first result into ``Detection`` objects.

    Returns an empty list when the result carries no boxes or no track ids.
    Ultralytics reports ``boxes.id`` as ``None`` in that case, which
    previously raised ``AttributeError`` on ``.int()``.
    """
    results = self.model.track(img, persist=True)
    boxes = results[0].boxes
    if boxes is None or boxes.id is None:
        return []

    track_ids = boxes.id.int().cpu().tolist()
    # ``xywh`` is centre-x, centre-y, width, height, whereas Detection stores
    # the left/top corner. Shift by half the size so get_foot_coords() lands
    # on the bottom-centre of the box instead of its bottom-right region.
    # ``tolist()`` yields plain floats rather than 0-d tensors.
    centre_boxes = boxes.xywh.cpu().tolist()
    return [
        Detection(track_id, cx - 0.5 * w, cy - 0.5 * h, w, h)
        for track_id, (cx, cy, w, h) in zip(track_ids, centre_boxes)
    ]


def _resnet_track(self, img) -> "list[Detection]":
    """Detect persons in ``img`` and associate them to tracks via DeepSort.

    Runs ``self._resnet_detect_persons(img)``, feeds the detections together
    with the frame to ``self.mot_tracker.update_tracks`` and wraps every
    returned track in a ``Detection``.
    """
    detections = self._resnet_detect_persons(img)
    tracks = self.mot_tracker.update_tracks(detections, frame=img)
    return [Detection.from_deepsort(t) for t in tracks]