diff --git a/trap/tracker.py b/trap/tracker.py
index b794995..8d5c381 100644
--- a/trap/tracker.py
+++ b/trap/tracker.py
@@ -13,8 +13,9 @@
 import torch
 import zmq
 import cv2
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
 from deep_sort_realtime.deepsort_tracker import DeepSort
+from torchvision.models import ResNet50_Weights
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 from ultralytics import YOLO
 
@@ -22,8 +23,8 @@ from ultralytics.engine.results import Results as YOLOResult
 
 from trap.frame_emitter import Frame
 
-Detection = [int, int, int, int, float, int]
-Detections = [Detection]
+# Detection = [int, int, int, int, float, int]
+# Detections = [Detection]
 
 # This is the dt that is also used by the scene.
 # as this needs to be rather stable, try to adhere
@@ -33,11 +34,33 @@ TARGET_DT = .1
 
 logger = logging.getLogger("trap.tracker")
 
-DETECTOR_RESNET = 'resnet'
+DETECTOR_RETINANET = 'retinanet'
+DETECTOR_MASKRCNN = 'maskrcnn'
+DETECTOR_FASTERRCNN = 'fasterrcnn'
 DETECTOR_YOLOv8 = 'ultralytics'
 
-DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
+DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
 
+@dataclass
+class Detection:
+    track_id: str
+    l: int # left
+    t: int # top
+    w: int # width
+    h: int # height
+    conf: float #probablity
+
+    def get_foot_coords(self):
+        return [self.l + 0.5 * self.w, self.t+self.h]
+
+    @classmethod
+    def from_deepsort(cls, dstrack: DeepsortTrack):
+        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
+
+    def to_ltwh(self):
+        return (int(self.l), int(self.t), int(self.w), int(self.h))
+
+
 @dataclass
 class Track:
     """A bit of an haphazardous wrapper around the 'real' tracker to provide
@@ -55,22 +78,6 @@ class Track:
         return np.array([])
 
 
-@dataclass
-class Detection:
-    track_id: str
-    l: int # left
-    t: int # top
-    w: int # width
-    h: int # height
-
-    def get_foot_coords(self):
-        return [self.l + 0.5 * self.w, self.t+self.h]
-
-    @classmethod
-    def from_deepsort(cls, dstrack: DeepsortTrack):
-        return cls(dstrack.track_id, *dstrack.to_ltwh())
-
-
 
 class Tracker:
     def __init__(self, config: Namespace, is_running: Event):
@@ -93,27 +100,40 @@ class Tracker:
         # TODO: support removal
         self.tracks = defaultdict(lambda: Track())
 
-        if self.config.detector == DETECTOR_RESNET:
+        if self.config.detector == DETECTOR_RETINANET:
             # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
             # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
             weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
-            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
+            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
             self.model.to(self.device)
             # Put the model in inference mode
             self.model.eval()
             # Get the transforms for the model's weights
             self.preprocess = weights.transforms().to(self.device)
+            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+                # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+            )
+        elif self.config.detector == DETECTOR_MASKRCNN:
+            weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
+            self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
+            self.model.to(self.device)
+            # Put the model in inference mode
+            self.model.eval()
+            # Get the transforms for the model's weights
+            self.preprocess = weights.transforms().to(self.device)
+            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+                # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+            )
         elif self.config.detector == DETECTOR_YOLOv8:
             self.model = YOLO('EXPERIMENTS/yolov8x.pt')
         else:
-            raise RuntimeError("No valid detector specified. See --help")
+            raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
 
         # homography = list(source.glob('*img2world.txt'))[0]
         self.H = np.loadtxt(self.config.homography, delimiter=',')
 
-        self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
 
         logger.debug("Set up tracker")
@@ -136,17 +156,28 @@ class Tracker:
             # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
             training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
 
-        frame_i = 0
+        prev_frame_i = -1
+
         while self.is_running.is_set():
-            this_run_time = time.time()
-            # logger.debug(f'test {prev_run_time - this_run_time}')
-            time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
-            prev_run_time = time.time()
+            # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
+            # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
+            # so for now, timing should move to emitter
+            # this_run_time = time.time()
+            # # logger.debug(f'test {prev_run_time - this_run_time}')
+            # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
+            # prev_run_time = time.time()
+            start_time = time.time()
             msg = self.frame_sock.recv()
             frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
+
+            if frame.index > (prev_frame_i+1):
+                logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
+
+
+            prev_frame_i = frame.index
+
 
             # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
-            start_time = time.time()
 
 
             if self.config.detector == DETECTOR_YOLOv8:
@@ -176,18 +207,21 @@
                 coords = track.get_projected_history(self.H) # get full history
                 trajectories[tid] = {
                     "id": tid,
+                    "det_conf": detection.conf,
+                    "bbox": detection.to_ltwh(),
                     "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
                 }
 
            # logger.info(f"{trajectories}")
            frame.trajectories = trajectories
 
-            current_time = time.time()
-            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
             if self.config.bypass_prediction:
                 self.trajectory_socket.send_string(json.dumps(trajectories))
             else:
                 self.trajectory_socket.send(pickle.dumps(frame))
+            current_time = time.time()
+            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
+            # self.trajectory_socket.send_string(json.dumps(trajectories))
 
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
             # TODO: provide a track object that actually keeps history (unlike tracker)
@@ -196,13 +230,14 @@ class Tracker:
             # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
             if training_csv:
                 training_csv.writerows([{
-                    'frame_id': round(frame_i * 10., 1), # not really time
+                    'frame_id': round(frame.index * 10., 1), # not really time
                     'track_id': t['id'],
-                    'x': t['history'][-1]['x'],
-                    'y': t['history'][-1]['y'],
+                    'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
+                    'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
                 } for t in trajectories.values()])
                 training_frames += len(trajectories)
-            frame_i += 1
+            # print(time.time() - start_time)
+
 
         if training_fp:
             training_fp.close()
@@ -226,6 +261,9 @@
 
     def _yolov8_track(self, img) -> [Detection]:
         results: [YOLOResult] = self.model.track(img, persist=True)
+        if results[0].boxes is None or results[0].boxes.id is None:
+            # work around https://github.com/ultralytics/ultralytics/issues/5968
+            return []
         return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
 
     def _resnet_track(self, img) -> [Detection]:
@@ -233,7 +271,7 @@
         detections = self._resnet_detect_persons(img)
         tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
         return [Detection.from_deepsort(t) for t in tracks]
 
-    def _resnet_detect_persons(self, frame) -> Detections:
+    def _resnet_detect_persons(self, frame) -> [Detection]:
         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
         t = t.permute(2, 0, 1)