testing the tracker

Ruben van de Ven 2024-04-25 16:31:51 +02:00
parent ba4d2f7909
commit 7c05c060c3
10 changed files with 1016 additions and 217 deletions

poetry.lock (generated, 55 lines changed)

@@ -1817,9 +1817,9 @@ files = [
[package.dependencies]
numpy = [
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
    {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
]

[[package]]

@@ -1939,8 +1939,8 @@ files = [
[package.dependencies]
numpy = [
-    {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
    {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
+    {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
]
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"

@@ -1970,6 +1970,21 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.8.0)"]

+[[package]]
+name = "pandas_helper_calc"
+version = "0.0.1"
+description = ""
+optional = false
+python-versions = "*"
+files = []
+develop = false
+
+[package.source]
+type = "git"
+url = "https://github.com/scls19fr/pandas-helper-calc"
+reference = "HEAD"
+resolved_reference = "22df480f09c0fa96548833f9dee8f9128512641b"
+
[[package]]
name = "pandocfilters"
version = "1.5.0"

@@ -2961,6 +2976,24 @@ all = ["numpy", "pytest", "pytest-cov"]
test = ["pytest", "pytest-cov"]
vectorized = ["numpy"]

+[[package]]
+name = "simdkalman"
+version = "1.0.4"
+description = "Kalman filters vectorized as Single Instruction, Multiple Data"
+optional = false
+python-versions = "*"
+files = [
+    {file = "simdkalman-1.0.4-py2.py3-none-any.whl", hash = "sha256:fc2c6b9e540e0a26b39d087e78623d3c1e8c6677abf5d91111f5d49e328e1668"},
+]
+
+[package.dependencies]
+numpy = ">=1.9.0"
+
+[package.extras]
+dev = ["check-manifest"]
+docs = ["sphinx"]
+test = ["pylint"]
+
[[package]]
name = "six"
version = "1.16.0"

@@ -3300,6 +3333,22 @@ tqdm = "^4.65.0"
type = "directory"
url = "../Trajectron-plus-plus"

+[[package]]
+name = "tsmoothie"
+version = "1.0.5"
+description = "A python library for timeseries smoothing and outlier detection in a vectorized way."
+optional = false
+python-versions = ">=3"
+files = [
+    {file = "tsmoothie-1.0.5-py3-none-any.whl", hash = "sha256:dedf8d8e011562824abe41783bf33e1b9ee1424bc572853bb82408743316a90e"},
+    {file = "tsmoothie-1.0.5.tar.gz", hash = "sha256:d83fa0ccae32bde7b904d9581ebf137e8eb18629cc3563d7379ca5f92461f6f5"},
+]
+
+[package.dependencies]
+numpy = "*"
+scipy = "*"
+simdkalman = "*"
+
[[package]]
name = "types-python-dateutil"
version = "2.8.19.14"

@@ -3468,4 +3517,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10,<3.12,"
-content-hash = "c9d4fe6a1d054a835a689cee011753b900b696aa8a06b81aa7a10afc24a8bc70"
+content-hash = "66f062f9db921cfa83e576288d09fd9b959780eb189d95765934ae9a6769f200"

pyproject.toml

@@ -29,6 +29,8 @@ ultralytics = "^8.0.200"
ffmpeg-python = "^0.2.0"
torchreid = "^0.2.5"
gdown = "^4.7.1"
+pandas-helper-calc = {git = "https://github.com/scls19fr/pandas-helper-calc"}
+tsmoothie = "^1.0.5"

[build-system]
requires = ["poetry-core"]
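The two new dependencies support post-processing of tracked trajectories: pandas-helper-calc (installed straight from git) and tsmoothie, which pulls in simdkalman for vectorized Kalman smoothing. A minimal sketch of smoothing a noisy track with tsmoothie, assuming a (series, timesteps) array; the noise parameters are illustrative, not values from this commit:

    import numpy as np
    from tsmoothie.smoother import KalmanSmoother

    # one noisy 2D track, smoothed per axis; component_noise values are guesses
    track = np.random.rand(2, 50)  # rows: x and y over 50 frames
    smoother = KalmanSmoother(component='level_trend',
                              component_noise={'level': 0.03, 'trend': 0.01})
    smoother.smooth(track)
    smoothed = smoother.smooth_data  # same shape as the input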

(file name not shown) File diff suppressed because one or more lines are too long

test_tracker.ipynb (new file, 574 lines)

File diff suppressed because one or more lines are too long

trap/frame_emitter.py

@@ -1,5 +1,6 @@
from argparse import Namespace
from dataclasses import dataclass, field
+from enum import IntFlag
from itertools import cycle
import logging
from multiprocessing import Event
@@ -12,9 +13,25 @@ import numpy as np
import cv2
import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
+from deep_sort_realtime.deep_sort.track import TrackState as DeepsortTrackState

logger = logging.getLogger('trap.frame_emitter')

+class DetectionState(IntFlag):
+    Tentative = 1 # state before n_init (see DeepsortTrack)
+    Confirmed = 2 # after tentative
+    Lost = 4 # lost when DeepsortTrack.time_since_update > 0 but not Deleted
+
+    @classmethod
+    def from_deepsort_track(cls, track: DeepsortTrack):
+        if track.state == DeepsortTrackState.Tentative:
+            return cls.Tentative
+        if track.state == DeepsortTrackState.Confirmed:
+            if track.time_since_update > 0:
+                return cls.Lost
+            return cls.Confirmed
+        raise RuntimeError("Should not run into Deleted entries here")
+
@dataclass
class Detection:
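Because DetectionState is an IntFlag with power-of-two values, states can be combined into masks for filtering. A small illustrative sketch (the mask and list names are hypothetical, not part of this commit):

    # hypothetical filter: keep confirmed or lost detections, drop tentative ones
    VISIBLE = DetectionState.Confirmed | DetectionState.Lost
    keep = [d for d in detections if d.state & VISIBLE]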
@@ -24,13 +41,27 @@ class Detection:
    w: int # width - image space
    h: int # height - image space
    conf: float # object detector probablity
+    state: DetectionState

    def get_foot_coords(self):
        return [self.l + 0.5 * self.w, self.t+self.h]

    @classmethod
    def from_deepsort(cls, dstrack: DeepsortTrack):
-        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
+        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf, DetectionState.from_deepsort_track(dstrack))
+
+    def get_scaled(self, scale: float = 1):
+        if scale == 1:
+            return self
+
+        return Detection(
+            self.track_id,
+            self.l*scale,
+            self.t*scale,
+            self.w*scale,
+            self.h*scale,
+            self.conf,
+            self.state)

    def to_ltwh(self):
        return (int(self.l), int(self.t), int(self.w), int(self.h))
@@ -39,6 +70,7 @@ class Detection:
        return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))

+
@dataclass
class Track:
    """A bit of an haphazardous wrapper around the 'real' tracker to provide

@@ -63,6 +95,7 @@ class Track:
        return [{"x":c[0], "y":c[1]} for c in coords]

+
@dataclass
class Frame:
    index: int
@@ -71,6 +104,19 @@ class Frame:
    tracks: Optional[dict[str, Track]] = None
    H: Optional[np.array] = None

+    def aslist(self) -> [dict]:
+        return { t.track_id:
+            {
+                'id': t.track_id,
+                'history': t.get_projected_history(self.H).tolist(),
+                'det_conf': t.history[-1].conf,
+                # 'det_conf': trajectory_data[node.id]['det_conf'],
+                # 'bbox': trajectory_data[node.id]['bbox'],
+                # 'history': history.tolist(),
+                'predictions': t.predictions
+            } for t in self.tracks.values()
+        }
+
class FrameEmitter:
    '''
    Emit frame in a separate threat so they can be throttled,
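Despite its name, aslist returns a dict keyed by track id whose values are JSON-serializable. Roughly, for a hypothetical single track with two history points:

    {'3': {'id': '3',
           'history': [[0.5, 1.2], [0.6, 1.3]],  # projected world coordinates
           'det_conf': 0.87,
           'predictions': []}}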
@@ -95,6 +141,7 @@ class FrameEmitter:

    def emit_video(self):
+        i = 0
        for video_path in self.video_srcs:
            logger.info(f"Play from '{str(video_path)}'")
            video = cv2.VideoCapture(str(video_path))
@@ -102,8 +149,21 @@
            target_frame_duration = 1./fps
            logger.info(f"Emit frames at {fps} fps")

+            if '-' in video_path.stem:
+                path_stem = video_path.stem[:video_path.stem.rfind('-')]
+            else:
+                path_stem = video_path.stem
+            path_stem += "-homography"
+            homography_path = video_path.with_stem(path_stem).with_suffix('.txt')
+            logger.info(f'check homography file {homography_path}')
+            if homography_path.exists():
+                logger.info(f'Found custom homography file! Using {homography_path}')
+                video_H = np.loadtxt(homography_path, delimiter=',')
+            else:
+                video_H = None
+
            prev_time = time.time()
-            i = 0

            while self.is_running.is_set():
                ret, img = video.read()
@@ -120,7 +180,7 @@
                    # hack to mask out area
                    cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)

-                frame = Frame(index=i, img=img)
+                frame = Frame(index=i, img=img, H=video_H)
                # TODO: this is very dirty, need to find another way.
                # perhaps multiprocessing Array?
                self.frame_sock.send(pickle.dumps(frame))
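The emitter now looks for a per-video homography next to the source file: everything after the last '-' in the stem is dropped and '-homography' is appended, with a '.txt' suffix. A quick sketch of the convention with a hypothetical file name:

    from pathlib import Path
    import numpy as np

    video_path = Path('recordings/lounge-01.mp4')  # hypothetical input
    # stem 'lounge-01' -> 'lounge' -> 'lounge-homography.txt'
    homography_path = video_path.with_stem('lounge-homography').with_suffix('.txt')
    if homography_path.exists():
        H = np.loadtxt(homography_path, delimiter=',')  # comma-separated 3x3 matrix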

(prediction server module; file name not shown)

@@ -243,7 +243,7 @@ class PredictionServer:
            if self.config.predict_training_data:
                input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
            else:
-                zmq_ev = self.trajectory_socket.poll(timeout=3)
+                zmq_ev = self.trajectory_socket.poll(timeout=2000)
                if not zmq_ev:
                    # on no data loop so that is_running is checked
                    continue
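The poll timeout goes from 3 ms to 2000 ms here (and likewise in the renderer and tracker below): instead of spinning, the loop now blocks up to two seconds per iteration while still waking regularly to honour the shutdown flag. Condensed from the loops in this commit, the pattern is:

    # block up to 2 s for data, then re-check the running flag
    while is_running.is_set():              # multiprocessing.Event set by the supervisor
        if not socket.poll(timeout=2000):   # milliseconds; 0 on timeout
            continue                        # no data yet; loop so shutdown is noticed
        frame = socket.recv_pyobj()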
@@ -252,7 +252,7 @@
                frame: Frame = pickle.loads(data)
                # trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
                # trajectory_data = json.loads(data)
-                logger.debug(f"Receive {frame.index}")
+                # logger.debug(f"Receive {frame.index}")

            # class FakeNode:
            #     def __init__(self, node_type: NodeType):
@@ -276,12 +276,12 @@
                    ax = derivative_of(vx, 0.1)
                    ay = derivative_of(vy, 0.1)

-                    data_dict = {('position', 'x'): x[:],
-                                 ('position', 'y'): y[:],
-                                 ('velocity', 'x'): vx[:],
-                                 ('velocity', 'y'): vy[:],
-                                 ('acceleration', 'x'): ax[:],
-                                 ('acceleration', 'y'): ay[:]}
+                    data_dict = {('position', 'x'): x[:], # [-10:-1]
+                                 ('position', 'y'): y[:], # [-10:-1]
+                                 ('velocity', 'x'): vx[:], # [-10:-1]
+                                 ('velocity', 'y'): vy[:], # [-10:-1]
+                                 ('acceleration', 'x'): ax[:], # [-10:-1]
+                                 ('acceleration', 'y'): ay[:]} # [-10:-1]

                    data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
                    node_data = pd.DataFrame(data_dict, columns=data_columns)
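derivative_of is Trajectron++'s finite-difference helper, called with a fixed 0.1 s timestep. For intuition only, a rough numpy approximation (Trajectron's version handles the boundary samples differently):

    import numpy as np

    dt = 0.1                                 # the fixed timestep used above
    x = np.array([0.0, 0.1, 0.3, 0.6])       # hypothetical projected x positions
    vx = np.gradient(x, dt)                  # ~ derivative_of(x, 0.1): velocity
    ax = np.gradient(vx, dt)                 # second derivative: acceleration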
@@ -301,7 +301,7 @@
                    # TODO: we want to send out empty result...
                    # And want to update the network
-                    data = json.dumps({})
+                    # data = json.dumps({})
                    self.prediction_socket.send_pyobj(frame)
                    continue
@@ -325,7 +325,7 @@
                    warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
                    dists, preds = trajectron.incremental_forward(input_dict,
                                                                  maps,
-                                                                  prediction_horizon=25, # TODO: make variable
+                                                                  prediction_horizon=125, # TODO: make variable
                                                                  num_samples=5, # TODO: make variable
                                                                  robot_present_and_future=robot_present_and_future,
                                                                  full_dist=True)

trap/renderer.py

@@ -1,3 +1,4 @@
+import time
import ffmpeg
from argparse import Namespace
import datetime
@@ -8,7 +9,7 @@ import numpy as np
import zmq

-from trap.frame_emitter import Frame
+from trap.frame_emitter import DetectionState, Frame

logger = logging.getLogger("trap.renderer")
@@ -84,7 +85,7 @@ class Renderer:
        while self.is_running.is_set():
            i+=1

-            zmq_ev = self.frame_sock.poll(timeout=3)
+            zmq_ev = self.frame_sock.poll(timeout=2000)
            if not zmq_ev:
                # when no data comes in, loop so that is_running is checked
                continue
@@ -95,6 +96,32 @@
            except zmq.ZMQError as e:
                logger.debug(f'reuse prediction')

+            if first_time is None:
+                first_time = frame.time
+
+            decorate_frame(frame, prediction_frame, first_time)
+
+            img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
+
+            # cv2.imwrite(str(img_path), img)
+            logger.debug(f"write frame {frame.time - first_time:.3f}s")
+            if self.out_writer:
+                self.out_writer.write(img)
+            if self.streaming_process:
+                self.streaming_process.stdin.write(img.tobytes())
+
+        logger.info('Stopping')
+
+        if i>2:
+            if self.streaming_process:
+                self.streaming_process.stdin.close()
+
+            if self.out_writer:
+                self.out_writer.release()
+
+            if self.streaming_process:
+                # oddly wrapped, because both close and release() take time.
+                self.streaming_process.wait()
+
+def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.array:
    img = frame.img

    # all not working:
@@ -108,27 +135,31 @@ class Renderer:
    if not prediction_frame:
        cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
-        continue
+        # continue
    else:
+        inv_H = np.linalg.pinv(prediction_frame.H)
        for track_id, track in prediction_frame.tracks.items():
            if not len(track.history):
                continue
            # coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
            coords = [d.get_foot_coords() for d in track.history]
+            confirmations = [d.state == DetectionState.Confirmed for d in track.history]
            # logger.warning(f"{coords=}")

            for ci in range(1, len(coords)):
                start = [int(p) for p in coords[ci-1]]
                end = [int(p) for p in coords[ci]]
-                cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
+                color = (255,255,255) if confirmations[ci] else (100,100,100)
+                cv2.line(img, start, end, color, 2, lineType=cv2.LINE_AA)

            if not track.predictions or not len(track.predictions):
                continue

            for pred_i, pred in enumerate(track.predictions):
-                pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
-                color = (0,0,255) if pred_i == 1 else (100,100,100)
+                pred_coords = cv2.perspectiveTransform(np.array([pred]), inv_H)[0]
+                color = (0,0,255) if pred_i else (100,100,100)
                for ci in range(1, len(pred_coords)):
                    start = [int(p) for p in pred_coords[ci-1]]
                    end = [int(p) for p in pred_coords[ci]]
@@ -148,36 +179,20 @@
            cv2.rectangle(img, p1, p2, (255,0,0), 1)
            cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)

-    if first_time is None:
-        first_time = frame.time
-
    cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
    cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)

    if prediction_frame:
+        # render Δt and Δ frames
        cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+        cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+        cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+        cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()])}", (580, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+        cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()])}", (660, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+        cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()])}", (740, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+
+    return img

-    img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
-
-    # cv2.imwrite(str(img_path), img)
-    logger.info(f"write frame {frame.time - first_time:.3f}s")
-    if self.out_writer:
-        self.out_writer.write(img)
-    if self.streaming_process:
-        self.streaming_process.stdin.write(img.tobytes())
-
-    logger.info('Stopping')
-
-    if i>2:
-        if self.streaming_process:
-            self.streaming_process.stdin.close()
-
-        if self.out_writer:
-            self.out_writer.release()
-
-        if self.streaming_process:
-            # oddly wrapped, because both close and release() take time.
-            self.streaming_process.wait()

def run_renderer(config: Namespace, is_running: Event):
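The drawing logic is extracted from the receive loop into a module-level decorate_frame, plausibly so it can also be exercised outside a running Renderer (for instance from the new test_tracker.ipynb). Note that the loop as committed calls decorate_frame(frame, prediction_frame, first_time) without capturing the return value, while the write path still references img; presumably the intended call is:

    # capture the decorated image before writing it out
    img = decorate_frame(frame, prediction_frame, first_time)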

(websocket router module; file name not shown)

@@ -28,7 +28,7 @@ class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
        self.zmq_socket = zmq_socket

    async def on_message(self, message):
-        logger.debug(f"recieve msg")
+        logger.debug(f"receive msg")
        try:
            await self.zmq_socket.send_string(message)
@@ -116,6 +116,7 @@ class WsRouter:
        context = zmq.asyncio.Context()
        self.trajectory_socket = context.socket(zmq.PUB)
+        logger.info(f'Publish trajectories on {config.zmq_trajectory_addr}')
        self.trajectory_socket.bind(config.zmq_trajectory_addr)

        self.prediction_socket = context.socket(zmq.SUB)

(tracker module; file name not shown)

@@ -22,7 +22,7 @@ from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult

-from trap.frame_emitter import Frame, Detection, Track
+from trap.frame_emitter import DetectionState, Frame, Detection, Track

# Detection = [int, int, int, int, float, int]
# Detections = [Detection]
@@ -66,6 +66,9 @@ class Tracker:
        # TODO: support removal
        self.tracks = defaultdict(lambda: Track())

+        logger.debug(f"Load tracker: {self.config.detector}")
+
        if self.config.detector == DETECTOR_RETINANET:
            # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
            # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
@@ -76,7 +79,7 @@
            self.model.eval()
            # Get the transforms for the model's weights
            self.preprocess = weights.transforms().to(self.device)
-            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
                                        )
        elif self.config.detector == DETECTOR_MASKRCNN:
@@ -87,7 +90,7 @@
            self.model.eval()
            # Get the transforms for the model's weights
            self.preprocess = weights.transforms().to(self.device)
-            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
+            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=15, nms_max_overlap=0.9,
                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
                                        )
        elif self.config.detector == DETECTOR_YOLOv8:
@@ -120,7 +123,7 @@
                logger.warning(f"Path for training-data exists: {self.config.save_for_training}. Continuing assuming that's ok.")
            training_fp = open(self.config.save_for_training / 'all.txt', 'w')
            # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
-            training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
+            training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state'], delimiter='\t', quoting=csv.QUOTE_NONE)

        prev_frame_i = -1
@@ -133,6 +136,12 @@
            #     time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
            # prev_run_time = time.time()

+            zmq_ev = self.frame_sock.poll(timeout=2000)
+            if not zmq_ev:
+                logger.warn('skip poll after 2000ms')
+                # when there's no data after timeout, loop so that is_running is checked
+                continue
+
            start_time = time.time()
            frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
@@ -142,6 +151,9 @@
            prev_frame_i = frame.index

            # load homography into frame (TODO: should this be done in emitter?)
+            if frame.H is None:
+                # logger.warning('Falling back to default H')
+                # fallback: load configured H
                frame.H = self.H

            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
@@ -150,7 +162,7 @@

            if self.config.detector == DETECTOR_YOLOv8:
                detections: [Detection] = self._yolov8_track(frame.img)
            else :
-                detections: [Detection] = self._resnet_track(frame.img)
+                detections: [Detection] = self._resnet_track(frame.img, scale = 1)

            # Store detections into tracklets
@@ -201,10 +213,18 @@
                if training_csv:
                    training_csv.writerows([{
                        'frame_id': round(frame.index * 10., 1), # not really time
-                        'track_id': t['id'],
-                        'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
-                        'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
-                    } for t in active_tracks.values()])
+                        'track_id': t.track_id,
+                        'l': t.history[-1].l,
+                        't': t.history[-1].t,
+                        'w': t.history[-1].w,
+                        'h': t.history[-1].h,
+                        'x': t.get_projected_history(frame.H)[-1][0],
+                        'y': t.get_projected_history(frame.H)[-1][1],
+                        'state': t.history[-1].state.value
+                        # only keep _actual_ detections, no lost entries
+                    } for t in active_tracks.values()
+                    # if t.history[-1].state != DetectionState.Lost
+                    ])
                    training_frames += len(active_tracks)

            # print(time.time() - start_time)
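The training dump now stores the raw bounding box (l, t, w, h), the projected world coordinates (x, y), and the DetectionState value per row. A minimal sketch for loading it back, assuming the all.txt written above; the filter mirrors the commented-out Lost exclusion:

    import pandas as pd

    cols = ['frame_id', 'track_id', 'l', 't', 'w', 'h', 'x', 'y', 'state']
    df = pd.read_csv('all.txt', sep='\t', names=cols)  # path as written by save_for_training
    confirmed = df[df['state'] != 4]                   # 4 == DetectionState.Lost.value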
@@ -236,10 +256,13 @@
            return []
        return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]

-    def _resnet_track(self, img) -> [Detection]:
+    def _resnet_track(self, img, scale: float = 1) -> [Detection]:
+        if scale != 1:
+            dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
+            img = cv2.resize(img, dsize)
        detections = self._resnet_detect_persons(img)
        tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
-        return [Detection.from_deepsort(t) for t in tracks]
+        return [Detection.from_deepsort(t).get_scaled(1/scale) for t in tracks]

    def _resnet_detect_persons(self, frame) -> [Detection]:
        t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
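_resnet_track can now detect on a downscaled copy and map boxes back to source resolution via Detection.get_scaled(1/scale). A sketch of the round trip with an illustrative scale (the commit itself still calls it with scale = 1); track_on stands in for the detect-and-track step:

    import cv2

    scale = 0.5                                   # illustrative, not the committed value
    dsize = (int(img.shape[1] * scale), int(img.shape[0] * scale))
    small = cv2.resize(img, dsize)                # detect and track on the smaller frame
    detections = track_on(small)                  # hypothetical detect+track helper
    full_res = [d.get_scaled(1 / scale) for d in detections]  # back to source pixels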

(web UI template; file name not shown)

@@ -25,12 +25,13 @@
    <script>
        // minified https://github.com/joewalnes/reconnecting-websocket
        [vendored reconnecting-websocket snippet re-indented; whitespace-only change, content identical]
    </script>
    <script>
        // map the field to coordinates of our dummy tracker
-        const field_range = { x: [-30, 10], y: [-10, 10] }
+        // see test_homography.ipynb for the logic behind these values
+        const field_range = { x: [-13.092, 15.37], y: [-4.66, 10.624] }

        // Create WebSocket connection.
        const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
@@ -96,17 +97,17 @@
        let current_pos = null;

-        function appendAndSendPositions(){
-            if(is_moving && current_pos!==null){
+        function appendAndSendPositions() {
+            if (is_moving && current_pos !== null) {
                // throttled update of tracker on movement
                tracker[person_counter].addToHistory(current_pos);
            }
-            for(const person_id in tracker){
-                if(person_id != person_counter){ // compare int/str
+            for (const person_id in tracker) {
+                if (person_id != person_counter) { // compare int/str
                    // fade out old tracks
                    tracker[person_id].history.shift()
-                    if(!tracker[person_id].history.length){
+                    if (!tracker[person_id].history.length) {
                        delete tracker[person_id]
                    }
                }
            }
@@ -125,7 +126,7 @@
            const mousePos = getMousePos(fieldEl, event);
            const position = mouse_coordinates_to_position(mousePos)
            current_pos = position;
-            // tracker[person_counter].addToHistory(current_pos);
+            tracker[person_counter].addToHistory(current_pos);
            // trajectory_socket.send(JSON.stringify(tracker))
        });
@@ -134,8 +135,8 @@
            const mousePos = getMousePos(fieldEl, event);
            const position = mouse_coordinates_to_position(mousePos)
            current_pos = position;
-            // tracker[person_counter].addToHistory(current_pos);
-            // trajectory_socket.send(JSON.stringify(tracker))
+            tracker[person_counter].addToHistory(current_pos);
+            trajectory_socket.send(JSON.stringify(tracker))
        });
        document.addEventListener('mouseup', (e) => {
            person_counter++;
@@ -170,12 +171,13 @@
                    ctx.stroke();
                }

-                if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
+                if (person.hasOwnProperty('predictions') && person.predictions.length > 0) {
                    // multiple predictions can be sampled
                    person.predictions.forEach((prediction, i) => {
                        ctx.beginPath()
-                        ctx.lineWidth = i === 1 ? 3 : 0.2;
+                        ctx.lineWidth = 0.2;
                        ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
+                        ctx.strokeStyle = "#ccaaaa";

                        // start from current position:
                        ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
@@ -184,6 +186,33 @@
                        }
                        ctx.stroke();
                    });
+
+                    // average stroke:
+                    ctx.beginPath()
+                    ctx.lineWidth = 3;
+                    ctx.strokeStyle = "#ff0000";
+
+                    // start from current position:
+                    ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
+
+                    for (let index = 0; index < person.predictions[0].length; index++) {
+                        sum = person.predictions.reduce(
+                            (accumulator, prediction) => ({
+                                "x": accumulator.x + prediction[index][0],
+                                "y": accumulator.y + prediction[index][1],
+                            }),
+                            { x: 0, y: 0 },
+                        );
+                        avg = { x: sum.x / person.predictions.length, y: sum.y / person.predictions.length }
+                        // console.log(sum, avg)
+                        ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(avg)))
+                    }
+                    // for (const position of ) {
+
+                    // }
+                    ctx.stroke();
                }
            }
            ctx.restore();
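The new red stroke draws the per-step mean of all sampled predictions instead of highlighting sample 1. The same reduction expressed in numpy, assuming a hypothetical predictions array shaped (samples, horizon, 2):

    import numpy as np

    preds = np.asarray(person_predictions)  # hypothetical: (num_samples, horizon, 2)
    avg_path = preds.mean(axis=0)           # per-timestep mean across samples -> (horizon, 2)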