Try different trackers. Ultralytics & KeypointRCNN
This commit is contained in:
parent
2434470cdf
commit
cc1e417db4
2 changed files with 95 additions and 23 deletions
|
@ -2,6 +2,8 @@ import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import types
|
import types
|
||||||
|
|
||||||
|
from trap.tracker import DETECTORS
|
||||||
|
|
||||||
from pyparsing import Optional
|
from pyparsing import Optional
|
||||||
|
|
||||||
class LambdaParser(argparse.ArgumentParser):
|
class LambdaParser(argparse.ArgumentParser):
|
||||||
|
@ -208,6 +210,10 @@ tracker_parser.add_argument("--save-for-training",
|
||||||
help="Specify the path in which to save",
|
help="Specify the path in which to save",
|
||||||
type=Path,
|
type=Path,
|
||||||
default=None)
|
default=None)
|
||||||
|
tracker_parser.add_argument("--detector",
|
||||||
|
help="Specify the detector to use",
|
||||||
|
type=str,
|
||||||
|
choices=DETECTORS)
|
||||||
|
|
||||||
# Renderer
|
# Renderer
|
||||||
|
|
||||||
|
|
100
trap/tracker.py
100
trap/tracker.py
|
@ -1,5 +1,7 @@
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from collections import defaultdict
|
||||||
import csv
|
import csv
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from multiprocessing import Event
|
from multiprocessing import Event
|
||||||
|
@ -11,9 +13,12 @@ import torch
|
||||||
import zmq
|
import zmq
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
|
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
||||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||||
from deep_sort_realtime.deep_sort.track import Track
|
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||||
|
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from ultralytics.engine.results import Results as YOLOResult
|
||||||
|
|
||||||
from trap.frame_emitter import Frame
|
from trap.frame_emitter import Frame
|
||||||
|
|
||||||
|
@ -28,6 +33,33 @@ TARGET_DT = .1
|
||||||
|
|
||||||
logger = logging.getLogger("trap.tracker")
|
logger = logging.getLogger("trap.tracker")
|
||||||
|
|
||||||
|
DETECTOR_RESNET = 'resnet'
|
||||||
|
DETECTOR_YOLOv8 = 'ultralytics'
|
||||||
|
|
||||||
|
DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Track:
|
||||||
|
track_id: str = None
|
||||||
|
history: [Detection]= field(default_factory=lambda: [])
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Detection:
|
||||||
|
track_id: str
|
||||||
|
l: int # left
|
||||||
|
t: int # top
|
||||||
|
w: int # width
|
||||||
|
h: int # height
|
||||||
|
|
||||||
|
def get_foot_coords(self):
|
||||||
|
return [self.l + 0.5 * self.w, self.t+self.h]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_deepsort(cls, dstrack: DeepsortTrack):
|
||||||
|
return cls(dstrack.track_id, *dstrack.to_ltwh())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Tracker:
|
class Tracker:
|
||||||
def __init__(self, config: Namespace, is_running: Event):
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -46,20 +78,30 @@ class Tracker:
|
||||||
# # TODO: config device
|
# # TODO: config device
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
|
# TODO: support removal
|
||||||
self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
|
self.tracks = defaultdict(lambda: Track())
|
||||||
|
|
||||||
|
if self.config.detector == DETECTOR_RESNET:
|
||||||
|
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
|
||||||
|
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
|
||||||
|
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
||||||
|
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
# Put the model in inference mode
|
# Put the model in inference mode
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
# Get the transforms for the model's weights
|
# Get the transforms for the model's weights
|
||||||
self.preprocess = weights.transforms().to(self.device)
|
self.preprocess = weights.transforms().to(self.device)
|
||||||
|
elif self.config.detector == DETECTOR_YOLOv8:
|
||||||
|
self.model = YOLO('EXPERIMENTS/yolov8x.pt')
|
||||||
|
else:
|
||||||
|
raise RuntimeError("No valid detector specified. See --help")
|
||||||
|
|
||||||
|
|
||||||
# homography = list(source.glob('*img2world.txt'))[0]
|
# homography = list(source.glob('*img2world.txt'))[0]
|
||||||
|
|
||||||
self.H = np.loadtxt(self.config.homography, delimiter=',')
|
self.H = np.loadtxt(self.config.homography, delimiter=',')
|
||||||
|
|
||||||
self.mot_tracker = DeepSort(max_age=5)
|
self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
|
||||||
logger.debug("Set up tracker")
|
logger.debug("Set up tracker")
|
||||||
|
|
||||||
|
|
||||||
|
@ -94,30 +136,46 @@ class Tracker:
|
||||||
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
detections = self.detect_persons(frame.img)
|
|
||||||
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
|
|
||||||
|
|
||||||
TEMP_boxes = [t.to_ltwh() for t in tracks]
|
if self.config.detector == DETECTOR_YOLOv8:
|
||||||
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
|
detections: [Detection] = self._yolov8_track(frame.img)
|
||||||
if len(TEMP_coords):
|
|
||||||
TEMP_proj_coords = cv2.perspectiveTransform(TEMP_coords,self.H)
|
|
||||||
else :
|
else :
|
||||||
TEMP_proj_coords = []
|
detections: [Detection] = self._resnet_track(frame.img)
|
||||||
|
|
||||||
|
|
||||||
|
# Store detections into tracklets
|
||||||
|
for detection in detections:
|
||||||
|
track = self.tracks[detection.track_id]
|
||||||
|
track.track_id = detection.track_id # for new tracks
|
||||||
|
|
||||||
|
track.history.append(detection)
|
||||||
|
# if len(track.history) > 30: # retain 90 tracks for 90 frames
|
||||||
|
# track.history.pop(0)
|
||||||
|
|
||||||
|
foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
|
||||||
|
if len(foot_coordinates):
|
||||||
|
projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
|
||||||
|
else:
|
||||||
|
projected_coordinates = []
|
||||||
|
|
||||||
# print(TEMP_proj_coords)
|
# print(TEMP_proj_coords)
|
||||||
trajectories = {}
|
trajectories = {}
|
||||||
for i, coords in enumerate(TEMP_proj_coords):
|
for detection, coords in zip(detections, projected_coordinates):
|
||||||
tid = tracks[i].track_id
|
tid = str(detection.track_id)
|
||||||
trajectories[tid] = {
|
trajectories[tid] = {
|
||||||
"id": tid,
|
"id": tid,
|
||||||
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
|
"history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test
|
||||||
}
|
}
|
||||||
# logger.debug(f"{trajectories}")
|
# logger.info(f"{trajectories}")
|
||||||
frame.trajectories = trajectories
|
frame.trajectories = trajectories
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||||
|
if self.config.bypass_prediction:
|
||||||
|
self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||||
|
else:
|
||||||
self.trajectory_socket.send(pickle.dumps(frame))
|
self.trajectory_socket.send(pickle.dumps(frame))
|
||||||
|
|
||||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||||
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
||||||
# TODO: provide a track object that actually keeps history (unlike tracker)
|
# TODO: provide a track object that actually keeps history (unlike tracker)
|
||||||
|
@ -154,8 +212,16 @@ class Tracker:
|
||||||
|
|
||||||
logger.info('Stopping')
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
def _yolov8_track(self, img) -> [Detection]:
|
||||||
|
results: [YOLOResult] = self.model.track(img, persist=True)
|
||||||
|
return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
|
||||||
|
|
||||||
def detect_persons(self, frame) -> Detections:
|
def _resnet_track(self, img) -> [Detection]:
|
||||||
|
detections = self._resnet_detect_persons(img)
|
||||||
|
tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
|
||||||
|
return [Detection.from_deepsort(t) for t in tracks]
|
||||||
|
|
||||||
|
def _resnet_detect_persons(self, frame) -> Detections:
|
||||||
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||||
# change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
|
# change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
|
||||||
t = t.permute(2, 0, 1)
|
t = t.permute(2, 0, 1)
|
||||||
|
|
Loading…
Reference in a new issue