Try more detector options
parent f3b8e031c1
commit a3e42b4501

1 changed file with 76 additions and 38 deletions

trap/tracker.py (114 lines changed: +76, -38)
@@ -13,8 +13,9 @@ import torch
 import zmq
 import cv2
 
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
 from deep_sort_realtime.deepsort_tracker import DeepSort
+from torchvision.models import ResNet50_Weights
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 
 from ultralytics import YOLO
@@ -22,8 +23,8 @@ from ultralytics.engine.results import Results as YOLOResult
 
 from trap.frame_emitter import Frame
 
-Detection = [int, int, int, int, float, int]
-Detections = [Detection]
+# Detection = [int, int, int, int, float, int]
+# Detections = [Detection]
 
 # This is the dt that is also used by the scene.
 # as this needs to be rather stable, try to adhere
@@ -33,11 +34,33 @@ TARGET_DT = .1
 
 logger = logging.getLogger("trap.tracker")
 
-DETECTOR_RESNET = 'resnet'
+DETECTOR_RETINANET = 'retinanet'
+DETECTOR_MASKRCNN = 'maskrcnn'
+DETECTOR_FASTERRCNN = 'fasterrcnn'
 DETECTOR_YOLOv8 = 'ultralytics'
 
-DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
+DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
 
+@dataclass
+class Detection:
+    track_id: str
+    l: int # left
+    t: int # top
+    w: int # width
+    h: int # height
+    conf: float # probability
+
+    def get_foot_coords(self):
+        return [self.l + 0.5 * self.w, self.t+self.h]
+
+    @classmethod
+    def from_deepsort(cls, dstrack: DeepsortTrack):
+        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
+
+    def to_ltwh(self):
+        return (int(self.l), int(self.t), int(self.w), int(self.h))
+
+
 @dataclass
 class Track:
     """A bit of an haphazardous wrapper around the 'real' tracker to provide
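A quick sketch of how the new Detection dataclass behaves, with invented values (this example is not part of the commit):

    det = Detection(track_id="7", l=120, t=80, w=40, h=160, conf=0.88)
    det.get_foot_coords()  # [140.0, 240]: bottom-center of the box, the point projected onto the ground plane
    det.to_ltwh()          # (120, 80, 40, 160): integer left/top/width/height tuple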
@@ -55,22 +78,6 @@ class Track:
         return np.array([])
 
-
-@dataclass
-class Detection:
-    track_id: str
-    l: int # left
-    t: int # top
-    w: int # width
-    h: int # height
-
-    def get_foot_coords(self):
-        return [self.l + 0.5 * self.w, self.t+self.h]
-
-    @classmethod
-    def from_deepsort(cls, dstrack: DeepsortTrack):
-        return cls(dstrack.track_id, *dstrack.to_ltwh())
-
 
 class Tracker:
     def __init__(self, config: Namespace, is_running: Event):
@@ -93,27 +100,40 @@ class Tracker:
         # TODO: support removal
         self.tracks = defaultdict(lambda: Track())
 
-        if self.config.detector == DETECTOR_RESNET:
+        if self.config.detector == DETECTOR_RETINANET:
             # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
             # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
             weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
-            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
+            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
             self.model.to(self.device)
             # Put the model in inference mode
             self.model.eval()
             # Get the transforms for the model's weights
             self.preprocess = weights.transforms().to(self.device)
+            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+                # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+                )
+        elif self.config.detector == DETECTOR_MASKRCNN:
+            weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
+            self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
+            self.model.to(self.device)
+            # Put the model in inference mode
+            self.model.eval()
+            # Get the transforms for the model's weights
+            self.preprocess = weights.transforms().to(self.device)
+            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+                # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
+                )
         elif self.config.detector == DETECTOR_YOLOv8:
             self.model = YOLO('EXPERIMENTS/yolov8x.pt')
         else:
-            raise RuntimeError("No valid detector specified. See --help")
+            raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
 
 
         # homography = list(source.glob('*img2world.txt'))[0]
 
         self.H = np.loadtxt(self.config.homography, delimiter=',')
 
-        self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
 
         logger.debug("Set up tracker")
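For orientation, a sketch (not from this commit) of how these pieces fit together: deep_sort_realtime's update_tracks expects raw detections as ([left, top, width, height], confidence, class) tuples, so the torchvision branches have to convert model output before handing it to the per-detector DeepSort instance created above. Roughly:

    # Hypothetical single person detection in deep_sort_realtime's documented input format.
    raw_detections = [([10, 20, 50, 120], 0.9, 0)]  # ([l, t, w, h], confidence, class)
    tracks = self.mot_tracker.update_tracks(raw_detections, frame=img)
    detections = [Detection.from_deepsort(t) for t in tracks]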
@@ -136,17 +156,28 @@ class Tracker:
         # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
         training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
 
-        frame_i = 0
+        prev_frame_i = -1
 
         while self.is_running.is_set():
-            this_run_time = time.time()
-            # logger.debug(f'test {prev_run_time - this_run_time}')
-            time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
-            prev_run_time = time.time()
+            # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
+            # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
+            # so for now, timing should move to emitter
+            # this_run_time = time.time()
+            # # logger.debug(f'test {prev_run_time - this_run_time}')
+            # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
+            # prev_run_time = time.time()
+
+            start_time = time.time()
             msg = self.frame_sock.recv()
             frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
 
+            if frame.index > (prev_frame_i+1):
+                logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
+
+            prev_frame_i = frame.index
+
             # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
-            start_time = time.time()
 
             if self.config.detector == DETECTOR_YOLOv8:
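The sleep-based pacing is commented out because, per the new comment, it dropped exactly one frame per second on a 10 fps video; the loop now just counts gaps in frame.index instead. The check in isolation, with hypothetical indices:

    prev_frame_i = 3
    frame_index = 6                               # frames 4 and 5 never arrived
    if frame_index > (prev_frame_i + 1):
        dropped = frame_index - prev_frame_i - 1  # 2 dropped frames get logged
    prev_frame_i = frame_index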
@@ -176,18 +207,21 @@ class Tracker:
                 coords = track.get_projected_history(self.H) # get full history
                 trajectories[tid] = {
                     "id": tid,
+                    "det_conf": detection.conf,
+                    "bbox": detection.to_ltwh(),
                     "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
                 }
             # logger.info(f"{trajectories}")
             frame.trajectories = trajectories
 
-            current_time = time.time()
-            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
             if self.config.bypass_prediction:
                 self.trajectory_socket.send_string(json.dumps(trajectories))
             else:
                 self.trajectory_socket.send(pickle.dumps(frame))
 
+            current_time = time.time()
+            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
+
             # self.trajectory_socket.send_string(json.dumps(trajectories))
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
             # TODO: provide a track object that actually keeps history (unlike tracker)
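With det_conf and bbox added, each entry in the trajectories dict sent over the socket looks roughly like this (values invented for illustration):

    trajectories[12] = {
        "id": 12,
        "det_conf": 0.87,                 # confidence of the latest detection
        "bbox": (334, 129, 42, 168),      # left, top, width, height in image pixels
        "history": [{"x": 1.2, "y": 3.4}, {"x": 1.3, "y": 3.5}],  # homography-projected points
    }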
@@ -196,13 +230,14 @@ class Tracker:
             # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
             if training_csv:
                 training_csv.writerows([{
-                    'frame_id': round(frame_i * 10., 1), # not really time
+                    'frame_id': round(frame.index * 10., 1), # not really time
                     'track_id': t['id'],
-                    'x': t['history'][-1]['x'],
-                    'y': t['history'][-1]['y'],
+                    'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
+                    'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
                 } for t in trajectories.values()])
                 training_frames += len(trajectories)
-            frame_i += 1
+            # print(time.time() - start_time)
+
 
         if training_fp:
             training_fp.close()
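The writer follows the Trajectron++ pedestrian preprocessing format linked above: headerless, tab-separated frame_id, track_id, x, y rows, where frame_id is now derived from frame.index instead of the old local counter. A few invented rows for illustration (columns tab-separated in the actual file):

    0.0     1    5.33    9.16
    10.0    1    5.41    9.02
    10.0    2    1.87    4.55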
@@ -226,6 +261,9 @@ class Tracker:
 
     def _yolov8_track(self, img) -> [Detection]:
         results: [YOLOResult] = self.model.track(img, persist=True)
+        if results[0].boxes is None or results[0].boxes.id is None:
+            # work around https://github.com/ultralytics/ultralytics/issues/5968
+            return []
         return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
 
     def _resnet_track(self, img) -> [Detection]:
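Two hedged observations on _yolov8_track after this change, neither addressed in the commit itself: the Detection dataclass now has a required conf field, yet Detection(track_id, *bbox) still passes only five values; and ultralytics' boxes.xywh is center-based (center-x, center-y, width, height), while Detection.l and Detection.t read as left/top. If both hold, a matching call might look like this sketch:

    # Hypothetical adaptation, not part of the commit.
    boxes = results[0].boxes
    detections = []
    for xywh, tid, conf in zip(boxes.xywh.cpu(), boxes.id.int().cpu().tolist(), boxes.conf.cpu().tolist()):
        xc, yc, w, h = xywh.tolist()
        detections.append(Detection(tid, xc - w / 2, yc - h / 2, w, h, conf))  # center -> left/top, conf included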
@@ -233,7 +271,7 @@ class Tracker:
         tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
         return [Detection.from_deepsort(t) for t in tracks]
 
-    def _resnet_detect_persons(self, frame) -> Detections:
+    def _resnet_detect_persons(self, frame) -> [Detection]:
         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         # change axes of loaded image to be compatible with torch.io.read_image (which has C,W,H format instead of W,H,C)
         t = t.permute(2, 0, 1)