tweaking tracker, adding RT-DETR
This commit is contained in:
parent bd00e4fbd6, commit c56f6ff3b4
6 changed files with 94 additions and 27 deletions
@@ -2,10 +2,10 @@
 # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack
 
 tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
-track_high_thresh: 0.0001 # threshold for the first association
-track_low_thresh: 0.0001 # threshold for the second association
-new_track_thresh: 0.0001 # threshold for init new track if the detection does not match any tracks
-track_buffer: 50 # buffer to calculate the time when to remove tracks
-match_thresh: 0.95 # threshold for matching tracks
+track_high_thresh: 0.000001 # threshold for the first association
+track_low_thresh: 0.000001 # threshold for the second association
+new_track_thresh: 0.000001 # threshold for init new track if the detection does not match any tracks
+track_buffer: 10 # buffer to calculate the time when to remove tracks
+match_thresh: 0.99 # threshold for matching tracks
 fuse_score: True # Whether to fuse confidence scores with the iou distances before matching
 # min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
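Note: these near-zero thresholds effectively hand every candidate box to ByteTrack and leave match_thresh as the main gate. A minimal sketch of how Ultralytics picks this file up (the tracker code below passes it to model.track(); model and video names here are placeholders, not the project's setup):

# Sketch: pointing Ultralytics tracking at the custom ByteTrack config.
# Assumes ultralytics is installed and custom_bytetrack.yaml is on disk.
from ultralytics import YOLO

model = YOLO("yolo11x.pt")  # any Ultralytics detection model works here
for result in model.track(source="video.mp4", stream=True, persist=True,
                          tracker="custom_bytetrack.yaml", verbose=False):
    if result.boxes is not None and result.boxes.id is not None:
        print(result.boxes.id.int().tolist())  # stable track ids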
@@ -6,8 +6,9 @@ import logging
 import time
 from argparse import ArgumentParser, Namespace
 from multiprocessing.synchronize import Event as BaseEvent
-from typing import Dict
+from typing import Dict, List, Optional
 
+from charset_normalizer import detect
 import cv2
 import ffmpeg
 import numpy as np
@@ -15,6 +16,7 @@ import pyglet
 import zmq
 from pyglet import shapes
 
+from trap.base import Detection
 from trap.counter import CounterListerner
 from trap.frame_emitter import Frame, Track
 from trap.node import Node
@@ -28,6 +30,7 @@ class CvRenderer(Node):
     def setup(self):
         self.prediction_sock = self.sub(self.config.zmq_prediction_addr)
         self.tracker_sock = self.sub(self.config.zmq_trajectory_addr)
+        self.detector_sock = self.sub(self.config.zmq_detection_addr)
         self.frame_sock = self.sub(self.config.zmq_frame_addr)
 
         # self.H = self.config.H
@@ -46,6 +49,7 @@ class CvRenderer(Node):
         self.frame: Frame|None= None
         self.tracker_frame: Frame|None = None
         self.prediction_frame: Frame|None = None
+        self.detections: List[Detection]|None = None
 
         self.tracks: Dict[str, Track] = {}
         self.predictions: Dict[str, Track] = {}
@@ -159,11 +163,20 @@ class CvRenderer(Node):
             except zmq.ZMQError as e:
                 logger.debug(f'reuse tracks')
 
+            try:
+                self.detections = self.detector_sock.recv_pyobj(zmq.NOBLOCK)
+                # print('detections')
+            except zmq.ZMQError as e:
+                # print('no detections')
+                # idx = frame.index if frame else "NONE"
+                # logger.debug(f"reuse video frame {idx}")
+                pass
+
             if first_time is None:
                 first_time = frame.time
 
             # img = frame.img
-            img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters)
+            img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.detections, self.config.render_clusters)
 
             logger.debug(f"write frame {frame.time - first_time:.3f}s")
             if self.out_writer:
@@ -210,6 +223,12 @@ class CvRenderer(Node):
                              help='Manually specity communication addr for the trajectory messages',
                              type=str,
                              default="ipc:///tmp/feeds_traj")
+
+        render_parser.add_argument('--zmq-detection-addr',
+                             help='Manually specity communication addr for the detection messages',
+                             type=str,
+                             default="ipc:///tmp/feeds_dets")
+
         render_parser.add_argument('--zmq-prediction-addr',
                              help='Manually specity communication addr for the prediction messages',
                              type=str,
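The renderer now has a fourth subscription, for raw detections. A minimal standalone consumer of that feed, assuming the default ipc address and the pickled payloads (send_pyobj/recv_pyobj) used throughout this codebase:

# Sketch: subscribing to the raw detection feed published by the tracker.
# Address and payload format follow the defaults in this diff.
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.connect("ipc:///tmp/feeds_dets")
sock.setsockopt(zmq.SUBSCRIBE, b"")  # no topic filter: receive everything

while True:
    detections = sock.recv_pyobj()   # blocks; a List[Detection]
    for det in detections:
        print(det.track_id, det.l, det.t, det.w, det.h)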
@@ -270,7 +289,7 @@ def get_animation_position(track: Track, current_frame: Frame):
 
 
 
-def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True) -> np.array:
+def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], detections: Optional[List[Detection]], as_clusters = True) -> np.array:
     scale = 100
     # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
     # or https://github.com/pygobject/pycairo?tab=readme-ov-file
@@ -304,6 +323,19 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
     # cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
     cv2.rectangle(img, (0,0), (img.shape[1],25), (0,0,0), -1)
 
+    if detections:
+        for detection in detections:
+            points = [
+                detection.get_foot_coords(),
+                [detection.l, detection.t],
+                [detection.l + detection.w, detection.t + detection.h],
+            ]
+            points = frame.camera.points_img_to_world(points, scale)
+            points = [to_point(p) for p in points] # to int
+
+            cv2.rectangle(img, points[1], points[2], (255,255,0), 2)
+            cv2.circle(img, points[0], 5, (255,255,0), 2)
+
     def conversion(points):
         return convert_world_points_to_img_points(points, scale)
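points_img_to_world presumably maps pixel coordinates onto the ground plane through the camera homography before scaling to drawing units. A sketch of that operation with plain OpenCV (the H matrix and the scale convention are assumptions, not the project's Camera implementation):

# Sketch: image-to-world mapping via a planar homography.
import cv2
import numpy as np

def img_to_world(points, H, scale=100):
    pts = np.array([points], dtype=np.float32)   # shape (1, N, 2)
    world = cv2.perspectiveTransform(pts, H)[0]  # apply homography
    return world * scale                         # world units -> draw px

H = np.eye(3, dtype=np.float32)                  # placeholder homography
print(img_to_world([[320.0, 240.0]], H))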
@@ -9,6 +9,7 @@ from pathlib import Path
 from trap import node
 from trap.base import *
 from trap.base import LambdaParser
+from trap.gemma import ImgMovementFilter
 from trap.preview_renderer import FrameWriter
 from trap.video_sources import get_video_source
 
@@ -41,6 +42,7 @@ class FrameEmitter(node.Node):
         print(self.config.record)
         writer = FrameWriter(str(self.config.record), None, None) if self.config.record else None
         try:
+            processor = ImgMovementFilter()
             while self.run_loop():
 
                 try:
@@ -51,6 +53,8 @@ class FrameEmitter(node.Node):
 
                     frame = Frame(i, img=img, H=self.config.camera.H, camera=self.config.camera)
 
+                    # frame.img = processor.apply(frame.img)
+
                     # TODO: this is very dirty, need to find another way.
                     # perhaps multiprocessing Array?
                     self.frame_noimg_sock.send(pickle.dumps(frame.without_img()))
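ImgMovementFilter is wired in here but its application stays commented out, and its implementation is not part of this diff. One plausible reading, purely as an assumption for illustration, is a frame-differencing filter that suppresses static background:

# Sketch: an assumed stand-in for trap.gemma.ImgMovementFilter; the real
# class may work differently.
import cv2
import numpy as np

class MovementFilter:
    def __init__(self):
        self.prev = None

    def apply(self, img: np.ndarray) -> np.ndarray:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if self.prev is None:
            self.prev = gray
            return img
        diff = cv2.absdiff(gray, self.prev)      # keep moving pixels only
        self.prev = gray
        mask = cv2.threshold(diff, 25, 255, cv2.THRESH_BINARY)[1]
        return cv2.bitwise_and(img, img, mask=mask)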
@@ -37,7 +37,7 @@ Coordinate = Tuple[float, float]
 DeltaT = float # delta_t in seconds
 
 OPTION_GROW_ANOMALY_CIRCLE = False
-OPTION_RENDER_DIFF_SEGMENT = False
+OPTION_RENDER_DIFF_SEGMENT = True
 
 class LineGenerator(ABC):
     @abstractmethod
@@ -706,7 +706,7 @@ class DrawnScenario(TrackScenario):
             # dt: change speed. Divide to make slower
             # amp: amplitude of noise
             # frequency: make smaller to make longer waves
-            noisy_points = apply_perlin_noise_to_line_normal(self.drawn_positions, t/3, .3, .05)
+            noisy_points = apply_perlin_noise_to_line_normal(self.drawn_positions, t/5, .3, .02)
             drawable_points, alphas = points_fade_out_alpha_mask(noisy_points, track_age, TRACK_FADE_AFTER_DURATION, TRACK_END_FADE)
             color = SrgbaColor(1.,0.,1.,1.-self.lost_factor())
 
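The tweak slows the animation (t/5 instead of t/3) and lengthens the waves (frequency .02 instead of .05) at unchanged amplitude. A sketch of the underlying idea, displacing each point along the line's normal by 1D Perlin noise (uses the `noise` package; the project's apply_perlin_noise_to_line_normal may differ):

# Sketch: Perlin-noise displacement along a polyline's normals,
# mirroring the (dt, amp, frequency) parameters above.
import numpy as np
from noise import pnoise1

def perlin_displace(points, t, amp, freq):
    pts = np.asarray(points, dtype=float)
    out = pts.copy()
    for i in range(1, len(pts) - 1):
        tangent = pts[i + 1] - pts[i - 1]
        normal = np.array([-tangent[1], tangent[0]])   # rotate 90 degrees
        normal /= np.linalg.norm(normal) + 1e-9
        out[i] += normal * amp * pnoise1(i * freq + t) # 1D Perlin sample
    return out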
@@ -16,7 +16,7 @@ from trap.preview_renderer import DrawnTrack
 import trap.tracker
 from trap.config import parser
 from trap.frame_emitter import Camera, Detection, DetectionState, video_src_from_config, Frame
-from trap.tracker import DETECTOR_YOLOv8, FinalDisplacementFilter, Smoother, TrackReader, _yolov8_track, Track, TrainingDataWriter, Tracker, read_tracks_json
+from trap.tracker import DETECTOR_YOLOv8, FinalDisplacementFilter, Smoother, TrackReader, _ultralytics_track, Track, TrainingDataWriter, Tracker, read_tracks_json
 from collections import defaultdict
 
 import logging
@@ -461,9 +461,12 @@ def draw_track_projected(img: cv2.Mat, track: Track, color_index: int, camera: C
     for j in range(len(history)-1):
         # a = history[j]
         b = history[j+1]
+        detection = track.history[j+1]
+
+        color = point_color if detection.state == DetectionState.Confirmed else (100,100,100)
 
         # cv2.line(img, to_point(a), to_point(b), point_color, 1)
-        cv2.circle(img, to_point(b), 3, point_color, 2)
+        cv2.circle(img, to_point(b), 3, color, 2)
 
 
 def draw_track(img: cv2.Mat, track: Track, color_index: int):
@@ -28,12 +28,14 @@ from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                           keypointrcnn_resnet50_fpn,
                                           maskrcnn_resnet50_fpn_v2)
 from tsmoothie.smoother import ConvolutionSmoother, KalmanSmoother
-from ultralytics import YOLO
-from ultralytics.engine.results import Results as YOLOResult
+from ultralytics import YOLO, RTDETR
+from ultralytics.engine.model import Model as UltralyticsModel
+from ultralytics.engine.results import Results as UltralyticsResult
 
 from trap import timer
 from trap.frame_emitter import (Camera, DataclassJSONEncoder, Detection,
                                 DetectionState, Frame, Track)
+from trap.gemma import ImgMovementFilter
 from trap.node import Node
 
 # Detection = [int, int, int, int, float, int]
@@ -51,11 +53,12 @@ DETECTOR_RETINANET = 'retinanet'
 DETECTOR_MASKRCNN = 'maskrcnn'
 DETECTOR_FASTERRCNN = 'fasterrcnn'
 DETECTOR_YOLOv8 = 'ultralytics'
+DETECTOR_RTDETR = 'rtdetr'
 
 TRACKER_DEEPSORT = 'deepsort'
 TRACKER_BYTETRACK = 'bytetrack'
 
-DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
+DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8, DETECTOR_RTDETR]
 TRACKERS =[TRACKER_DEEPSORT, TRACKER_BYTETRACK]
 
 TRACKER_CONFIDENCE_MINIMUM = .2
@@ -63,9 +66,9 @@ TRACKER_BYTETRACK_MINIMUM = .1 # bytetrack can track items iwth lower thershold
 NON_MAXIMUM_SUPRESSION = 1
 RCNN_SCALE = .4 # seems to have no impact on detections in the corners
 
-def _yolov8_track(frame: Frame, model: YOLO, **kwargs) -> List[Detection]:
+def _ultralytics_track(img: cv2.Mat, frame_idx: int, model: UltralyticsModel, **kwargs) -> List[Detection]:
 
-    results: List[YOLOResult] = list(model.track(frame.img, persist=True, tracker="custom_bytetrack.yaml", verbose=False, conf=0.00001, **kwargs))
+    results: List[UltralyticsResult] = list(model.track(img, persist=True, tracker="custom_bytetrack.yaml", verbose=False, conf=0.000001, **kwargs))
 
     if results[0].boxes is None or results[0].boxes.id is None:
         # work around https://github.com/ultralytics/ultralytics/issues/5968
@@ -74,7 +77,7 @@ def _yolov8_track(frame: Frame, model: YOLO, **kwargs) -> List[Detection]:
     boxes = results[0].boxes.xywh.cpu()
     track_ids = results[0].boxes.id.int().cpu().tolist()
     classes = results[0].boxes.cls.int().cpu().tolist()
-    return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame.index, class_id) for bbox, track_id, class_id in zip(boxes, track_ids, classes)]
+    return [Detection(track_id, bbox[0]-.5*bbox[2], bbox[1]-.5*bbox[3], bbox[2], bbox[3], 1, DetectionState.Confirmed, frame_idx, class_id) for bbox, track_id, class_id in zip(boxes, track_ids, classes)]
 
 class Multifile():
     def __init__(self, srcs: List[Path]):
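Ultralytics reports boxes as centre-based xywh while Detection stores a top-left origin, hence the bbox[0]-.5*bbox[2] / bbox[1]-.5*bbox[3] shift. A quick worked check with made-up numbers:

# Sketch: centre-xywh -> left/top-wh, the conversion in the return above.
cx, cy, w, h = 100.0, 60.0, 40.0, 80.0  # Ultralytics centre-based box
l, t = cx - 0.5 * w, cy - 0.5 * h       # top-left corner
assert (l, t) == (80.0, 20.0)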
@@ -395,6 +398,8 @@ class Tracker(Node):
         # # TODO: config device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+        self.frame_preprocess = ImgMovementFilter()
+
         # TODO: support removal
         self.tracks: DefaultDict[str, Track] = defaultdict(lambda: Track())
 
@@ -436,7 +441,15 @@ class Tracker(Node):
             self.mot_tracker = TrackerWrapper.init_type(self.config.tracker)
         elif self.config.detector == DETECTOR_YOLOv8:
             # self.model = YOLO('EXPERIMENTS/yolov8x.pt')
-            self.model = YOLO('yolo11x.pt')
+            # best from arsen:
+            # self.model = YOLO('./tracker/all_yolo11-2-20-15-41/weights')
+            # self.model = YOLO('models/yolo11x-pose.pt')
+            # self.model = YOLO("models/yolo12l.pt")
+            self.model = YOLO("models/yolo12x.pt")
+            # NOTE: changing the model, also tweak imgsz in
+        elif self.config.detector == DETECTOR_RTDETR:
+            # self.model = RTDETR('models/rtdetr-x.pt') # drops frames
+            self.model = RTDETR('models/rtdetr-l.pt') # somewhat less good in corners, but less frame dropping == better tracking
         else:
             raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
 
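RTDETR exposes the same ultralytics.engine.model.Model interface as YOLO, which is what lets track_frame below treat both detectors through one code path. A minimal sketch of the shared API (weights paths follow the diff; source and settings are illustrative):

# Sketch: swapping detectors behind the common Ultralytics Model API.
from ultralytics import RTDETR, YOLO

use_rtdetr = True
model = RTDETR("models/rtdetr-l.pt") if use_rtdetr else YOLO("models/yolo12x.pt")
results = model.track("video.mp4", persist=True,
                      tracker="custom_bytetrack.yaml",
                      classes=[0, 15, 16],  # COCO: person, cat, dog
                      imgsz=960)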
@@ -455,14 +468,22 @@ class Tracker(Node):
 
         self.frame_sock = self.sub(self.config.zmq_frame_addr)
         self.trajectory_socket = self.pub(self.config.zmq_trajectory_addr)
+        self.detection_socket = self.pub(self.config.zmq_detection_addr)
 
         logger.debug("Set up tracker")
 
     def track_frame(self, frame: Frame):
-        if self.config.detector == DETECTOR_YOLOv8:
-            detections: List[Detection] = _yolov8_track(frame, self.model, classes=[0, 15, 16], imgsz=[1152, 640])
+        det_img = frame.img
+        # det_img = self.frame_preprocess.apply(frame.img)
+
+        if self.config.detector in [DETECTOR_YOLOv8, DETECTOR_RTDETR]:
+            # both ultralytics
+            detections: List[Detection] = _ultralytics_track(det_img, frame.index, self.model, classes=[0, 15, 16], imgsz=self.config.imgsz)
         else :
-            detections: List[Detection] = self._resnet_track(frame, scale = RCNN_SCALE)
+            detections: List[Detection] = self._resnet_track(det_img, frame.index, scale = RCNN_SCALE)
+
+        # emit raw detections
+        self.detection_socket.send_pyobj(detections)
 
         for detection in detections:
             track = self.tracks[detection.track_id]
@@ -475,8 +496,7 @@ class Tracker(Node):
             track.history.append(detection) # add to history
 
         return detections
-
 
 
     def run(self):
         """
         Live tracking of frames coming in over zmq
@@ -611,13 +631,12 @@ class Tracker(Node):
         logger.info('Stopping')
 
 
-    def _resnet_track(self, frame: Frame, scale: float = 1) -> List[Detection]:
-        img = frame.img
+    def _resnet_track(self, img: cv2.Mat, frame_idx: int, scale: float = 1) -> List[Detection]:
         if scale != 1:
             dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
             img = cv2.resize(img, dsize)
         detections = self._resnet_detect_persons(img)
-        tracks: List[Detection] = self.mot_tracker.track_detections(detections, img, frame.index)
+        tracks: List[Detection] = self.mot_tracker.track_detections(detections, img, frame_idx)
         # active_tracks = [t for t in tracks if t.is_confirmed()]
         return [d.get_scaled(1/scale) for d in tracks]
 
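Because detection runs on a downscaled copy (RCNN_SCALE = .4), the resulting boxes live in the small image's coordinates; get_scaled(1/scale) is what maps them back. A worked check of that roundtrip (assuming get_scaled multiplies box coordinates, which is not shown in this diff):

# Sketch: the downscale/upscale roundtrip around _resnet_detect_persons.
scale = 0.4
l, t, w, h = 40, 20, 16, 32                # box found on the 0.4x image
full = tuple(v * (1 / scale) for v in (l, t, w, h))
assert full == (100.0, 50.0, 40.0, 80.0)   # original-resolution box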
@@ -679,6 +698,11 @@ class Tracker(Node):
                              help='Manually specity communication addr for the trajectory messages',
                              type=str,
                              default="ipc:///tmp/feeds_traj")
+
+        argparser.add_argument('--zmq-detection-addr',
+                             help='Manually specity communication addr for the detection messages',
+                             type=str,
+                             default="ipc:///tmp/feeds_dets")
 
         argparser.add_argument("--save-for-training",
                              help="Specify the path in which to save",
@@ -697,6 +721,10 @@ class Tracker(Node):
         argparser.add_argument("--smooth-tracks",
                              help="Smooth the tracker tracks before sending them to the predictor",
                              action='store_true')
+        argparser.add_argument("--imgsz",
+                             help="Detector imgsz parameter (applicable to ultralytics detectors)",
+                             type=int,
+                             default=960)
         return argparser
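--imgsz sets the resolution Ultralytics resizes frames to before inference; larger values help small or distant subjects at a speed cost. Note the old hardcoded imgsz=[1152, 640] pair is replaced by a single int flag defaulting to 960. How the flag flows through, per the diff above (comment-only sketch):

# --imgsz 1280 on the CLI ends up as:
#   _ultralytics_track(det_img, frame.index, model,
#                      classes=[0, 15, 16], imgsz=1280)
# which forwards into Ultralytics as:
#   model.track(img, persist=True, tracker="custom_bytetrack.yaml",
#               verbose=False, conf=0.000001, classes=[0, 15, 16], imgsz=1280)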