Try different trackers. Ultralytics & KeypointRCNN

Ruben van de Ven 2023-10-20 18:49:51 +02:00
parent 2434470cdf
commit cc1e417db4
2 changed files with 95 additions and 23 deletions

Changed file 1 of 2:

@@ -2,6 +2,8 @@ import argparse
 from pathlib import Path
 import types
+from trap.tracker import DETECTORS
 from pyparsing import Optional

 class LambdaParser(argparse.ArgumentParser):
@@ -208,6 +210,10 @@ tracker_parser.add_argument("--save-for-training",
                     help="Specify the path in which to save",
                     type=Path,
                     default=None)
+tracker_parser.add_argument("--detector",
+                    help="Specify the detector to use",
+                    type=str,
+                    choices=DETECTORS)

 # Renderer
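With choices=DETECTORS, argparse itself rejects any detector name the tracker does not know. A minimal sketch of that behaviour (the parser construction here is abbreviated; only --detector and DETECTORS come from the diff above):

    import argparse

    DETECTORS = ['resnet', 'ultralytics']
    parser = argparse.ArgumentParser()
    parser.add_argument("--detector", help="Specify the detector to use",
                        type=str, choices=DETECTORS)
    args = parser.parse_args(['--detector', 'ultralytics'])  # accepted
    # parser.parse_args(['--detector', 'yolo'])  # exits: invalid choice: 'yolo'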

Changed file 2 of 2:

@@ -1,5 +1,7 @@
 from argparse import Namespace
+from collections import defaultdict
 import csv
+from dataclasses import dataclass, field
 import json
 import logging
 from multiprocessing import Event
@@ -11,9 +13,12 @@ import torch
 import zmq
 import cv2
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
 from deep_sort_realtime.deepsort_tracker import DeepSort
-from deep_sort_realtime.deep_sort.track import Track
+from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
+from ultralytics import YOLO
+from ultralytics.engine.results import Results as YOLOResult

 from trap.frame_emitter import Frame
@@ -28,6 +33,33 @@ TARGET_DT = .1

 logger = logging.getLogger("trap.tracker")

+DETECTOR_RESNET = 'resnet'
+DETECTOR_YOLOv8 = 'ultralytics'
+DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
+
+@dataclass
+class Detection:
+    track_id: str
+    l: int  # left
+    t: int  # top
+    w: int  # width
+    h: int  # height
+
+    def get_foot_coords(self):
+        # bottom centre of the box: roughly where the person touches the ground
+        return [self.l + 0.5 * self.w, self.t + self.h]
+
+    @classmethod
+    def from_deepsort(cls, dstrack: DeepsortTrack):
+        return cls(dstrack.track_id, *dstrack.to_ltwh())
+
+@dataclass
+class Track:
+    track_id: str = None
+    # Detection is defined above so this annotation resolves at class-creation time
+    history: [Detection] = field(default_factory=lambda: [])

 class Tracker:
     def __init__(self, config: Namespace, is_running: Event):
         self.config = config
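As a quick sanity check of the Detection arithmetic above (the values are invented):

    d = Detection(track_id='1', l=100, t=50, w=40, h=160)
    assert d.get_foot_coords() == [120.0, 210]  # x = l + w/2, y = t + h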
@@ -46,20 +78,30 @@ class Tracker:
         # # TODO: config device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-        weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
-        self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
+        # TODO: support removal
+        self.tracks = defaultdict(lambda: Track())
+
+        if self.config.detector == DETECTOR_RESNET:
+            # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
+            # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
+            weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
+            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
             self.model.to(self.device)
             # Put the model in inference mode
             self.model.eval()
             # Get the transforms for the model's weights
             self.preprocess = weights.transforms().to(self.device)
+        elif self.config.detector == DETECTOR_YOLOv8:
+            self.model = YOLO('EXPERIMENTS/yolov8x.pt')
+        else:
+            raise RuntimeError("No valid detector specified. See --help")

         # homography = list(source.glob('*img2world.txt'))[0]
         self.H = np.loadtxt(self.config.homography, delimiter=',')

-        self.mot_tracker = DeepSort(max_age=5)
+        self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
         logger.debug("Set up tracker")
@@ -94,30 +136,46 @@ class Tracker:
             # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
             start_time = time.time()

-            detections = self.detect_persons(frame.img)
-            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
-            TEMP_boxes = [t.to_ltwh() for t in tracks]
-            TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
-            if len(TEMP_coords):
-                TEMP_proj_coords = cv2.perspectiveTransform(TEMP_coords,self.H)
+            if self.config.detector == DETECTOR_YOLOv8:
+                detections: [Detection] = self._yolov8_track(frame.img)
+            else:
+                detections: [Detection] = self._resnet_track(frame.img)
+
+            # Store detections into tracklets
+            for detection in detections:
+                track = self.tracks[detection.track_id]
+                track.track_id = detection.track_id  # for new tracks
+                track.history.append(detection)
+                # if len(track.history) > 30:  # retain 90 tracks for 90 frames
+                #     track.history.pop(0)
+
+            foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
+            if len(foot_coordinates):
+                projected_coordinates = cv2.perspectiveTransform(foot_coordinates, self.H)
             else:
-                TEMP_proj_coords = []
+                projected_coordinates = []
             # print(TEMP_proj_coords)

             trajectories = {}
-            for i, coords in enumerate(TEMP_proj_coords):
-                tid = tracks[i].track_id
+            for detection, coords in zip(detections, projected_coordinates):
+                tid = str(detection.track_id)
                 trajectories[tid] = {
                     "id": tid,
-                    "history": [{"x":c[0], "y":c[1]} for c in coords]  # already doubly nested, fine for test
+                    "history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist()  # already doubly nested, fine for test
                 }
-            # logger.debug(f"{trajectories}")
+            # logger.info(f"{trajectories}")
             frame.trajectories = trajectories

             current_time = time.time()
             logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")

-            self.trajectory_socket.send(pickle.dumps(frame))
+            if self.config.bypass_prediction:
+                self.trajectory_socket.send_string(json.dumps(trajectories))
+            else:
+                self.trajectory_socket.send(pickle.dumps(frame))
             # self.trajectory_socket.send_string(json.dumps(trajectories))
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
             # TODO: provide a track object that actually keeps history (unlike tracker)
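When bypass_prediction is set, the JSON sent over the trajectory socket has exactly the shape sketched in the comment above, e.g. (IDs and coordinates invented):

    {"1": {"id": "1", "history": [[6.0, 10.5], [6.1, 10.4]]},
     "2": {"id": "2", "history": [[2.3, 4.2]]}}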
@@ -154,8 +212,16 @@ class Tracker:
         logger.info('Stopping')

-    def detect_persons(self, frame) -> Detections:
+    def _yolov8_track(self, img) -> [Detection]:
+        results: [YOLOResult] = self.model.track(img, persist=True)
+        # ultralytics' boxes.xywh is centre-based; shift to left/top for Detection's ltwh fields
+        return [Detection(track_id, bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2, bbox[2], bbox[3]) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
+
+    def _resnet_track(self, img) -> [Detection]:
+        detections = self._resnet_detect_persons(img)
+        tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
+        return [Detection.from_deepsort(t) for t in tracks]
+
+    def _resnet_detect_persons(self, frame) -> Detections:
         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         # change axes of the loaded image to be compatible with torchvision.io.read_image (which returns C,H,W instead of H,W,C)
         t = t.permute(2, 0, 1)
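The rest of _resnet_detect_persons is cut off in this hunk; a minimal sketch of the torchvision flow it sets up (the weights and box_score_thresh come from the diff above, everything else is illustrative):

    import cv2
    import torch
    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20).eval()
    preprocess = weights.transforms()

    img = cv2.imread('frame.jpg')  # illustrative path; any H,W,C BGR frame works
    t = torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)
    with torch.no_grad():
        prediction = model([preprocess(t)])[0]  # detection models take a list of C,H,W tensors
    person_boxes = prediction['boxes']  # xyxy; COCO KeypointRCNN detects persons only (label 1)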