Compare commits

..

10 commits

Author SHA1 Message Date
Ruben van de Ven
cd8a06a53b list which processes occupy port 2023-12-06 12:08:12 +01:00
Ruben van de Ven
44a618a5ee Refactor: prediction_server produces Frame objects 2023-12-06 12:07:52 +01:00
Ruben van de Ven
a3e42b4501 Try more detector options 2023-12-06 10:25:50 +01:00
Ruben van de Ven
f3b8e031c1 Fix dt 2023-12-06 10:25:01 +01:00
Ruben van de Ven
ec9bb357fd Feature: render to zmq and tag frame with index 2023-12-06 10:24:45 +01:00
Ruben van de Ven
104098d371 Find & create homography for test data 2023-12-06 10:21:43 +01:00
Ruben van de Ven
1aed4161f6 Change default to no-loop 2023-10-27 11:47:36 +02:00
Ruben van de Ven
fd2e8a3b49 Timeout on zmq to allow gracefull exit 2023-10-22 14:25:34 +02:00
Ruben van de Ven
3091557733 Send history of track to predictor 2023-10-22 14:25:01 +02:00
Ruben van de Ven
b911caa5af Fix not-stopping on loop 2023-10-22 14:03:18 +02:00
10 changed files with 1392 additions and 1253 deletions

1813
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -26,6 +26,9 @@ torchvision = [
] ]
deep-sort-realtime = "^1.3.2" deep-sort-realtime = "^1.3.2"
ultralytics = "^8.0.200" ultralytics = "^8.0.200"
ffmpeg-python = "^0.2.0"
torchreid = "^0.2.5"
gdown = "^4.7.1"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]

383
test_homography.ipynb Normal file

File diff suppressed because one or more lines are too long

View file

@ -194,8 +194,8 @@ frame_emitter_parser.add_argument("--video-src",
default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4'))) default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
#TODO: camera as source #TODO: camera as source
frame_emitter_parser.add_argument("--video-no-loop", frame_emitter_parser.add_argument("--video-loop",
help="By default it emitter will run indefiniately. This prevents that and plays every video only once.", help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
action='store_true') action='store_true')
#TODO: camera as source #TODO: camera as source
@ -217,7 +217,17 @@ tracker_parser.add_argument("--detector",
# Renderer # Renderer
render_parser.add_argument("--render-preview", render_parser.add_argument("--render-file",
help="Render a video file previewing the prediction, and its delay compared to the current frame", help="Render a video file previewing the prediction, and its delay compared to the current frame",
action='store_true') action='store_true')
render_parser.add_argument("--render-url",
help="""Stream renderer on given URL. Two easy approaches:
- using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
""",
type=str,
default=None)

View file

@ -3,21 +3,73 @@ from dataclasses import dataclass, field
from itertools import cycle from itertools import cycle
import logging import logging
from multiprocessing import Event from multiprocessing import Event
from pathlib import Path
import pickle import pickle
import sys import sys
import time import time
from typing import Optional from typing import Iterable, Optional
import numpy as np import numpy as np
import cv2 import cv2
import zmq import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
logger = logging.getLogger('trap.frame_emitter') logger = logging.getLogger('trap.frame_emitter')
@dataclass
class Detection:
track_id: str # deepsort track id association
l: int # left - image space
t: int # top - image space
w: int # width - image space
h: int # height - image space
conf: float # object detector probablity
def get_foot_coords(self):
return [self.l + 0.5 * self.w, self.t+self.h]
@classmethod
def from_deepsort(cls, dstrack: DeepsortTrack):
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
def to_ltwh(self):
return (int(self.l), int(self.t), int(self.w), int(self.h))
def to_ltrb(self):
return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
@dataclass
class Track:
"""A bit of an haphazardous wrapper around the 'real' tracker to provide
a history, with which the predictor can work, as we then can deduce velocity
and acceleration.
"""
track_id: str = None
history: [Detection] = field(default_factory=lambda: [])
predictor_history: Optional[list] = None # in image space
predictions: Optional[list] = None
def get_projected_history(self, H) -> np.array:
foot_coordinates = [d.get_foot_coords() for d in self.history]
if len(foot_coordinates):
coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
return coords[0]
return np.array([])
def get_projected_history_as_dict(self, H) -> dict:
coords = self.get_projected_history(H)
return [{"x":c[0], "y":c[1]} for c in coords]
@dataclass @dataclass
class Frame: class Frame:
index: int
img: np.array img: np.array
time: float= field(default_factory=lambda: time.time()) time: float= field(default_factory=lambda: time.time())
trajectories: Optional[dict] = None tracks: Optional[dict[str, Track]] = None
H: Optional[np.array] = None
class FrameEmitter: class FrameEmitter:
''' '''
@ -36,10 +88,10 @@ class FrameEmitter:
logger.info(f"Connection socket {config.zmq_frame_addr}") logger.info(f"Connection socket {config.zmq_frame_addr}")
if self.config.video_no_loop: if self.config.video_loop:
self.video_srcs = self.config.video_src self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
else: else:
self.video_srcs = cycle(self.config.video_src) self.video_srcs: [Path] = self.config.video_src
def emit_video(self): def emit_video(self):
@ -51,6 +103,7 @@ class FrameEmitter:
logger.info(f"Emit frames at {fps} fps") logger.info(f"Emit frames at {fps} fps")
prev_time = time.time() prev_time = time.time()
i = 0
while self.is_running.is_set(): while self.is_running.is_set():
ret, img = video.read() ret, img = video.read()
@ -61,7 +114,13 @@ class FrameEmitter:
# video.set(cv2.CAP_PROP_POS_FRAMES, 0) # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
# ret, img = video.read() # ret, img = video.read()
# assert ret is not False # not really error proof... # assert ret is not False # not really error proof...
frame = Frame(img=img)
if "DATASETS/hof/" in str(video_path):
# hack to mask out area
cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
frame = Frame(index=i, img=img)
# TODO: this is very dirty, need to find another way. # TODO: this is very dirty, need to find another way.
# perhaps multiprocessing Array? # perhaps multiprocessing Array?
self.frame_sock.send(pickle.dumps(frame)) self.frame_sock.send(pickle.dumps(frame))
@ -75,10 +134,13 @@ class FrameEmitter:
else: else:
prev_time = new_frame_time prev_time = new_frame_time
i += 1
if not self.is_running.is_set(): if not self.is_running.is_set():
# if not running, also break out of infinite generator loop # if not running, also break out of infinite generator loop
break break
logger.info("Stopping") logger.info("Stopping")

View file

@ -73,7 +73,7 @@ def start():
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'), ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
] ]
if args.render_preview: if args.render_file or args.render_url:
procs.append( procs.append(
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer') ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
) )

View file

@ -27,6 +27,7 @@ import matplotlib.pyplot as plt
import zmq import zmq
from trap.frame_emitter import Frame from trap.frame_emitter import Frame
from trap.tracker import Track
logger = logging.getLogger("trap.prediction") logger = logging.getLogger("trap.prediction")
@ -242,33 +243,38 @@ class PredictionServer:
if self.config.predict_training_data: if self.config.predict_training_data:
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state']) input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
else: else:
zmq_ev = self.trajectory_socket.poll(timeout=3)
if not zmq_ev:
# on no data loop so that is_running is checked
continue
data = self.trajectory_socket.recv() data = self.trajectory_socket.recv()
frame: Frame = pickle.loads(data) frame: Frame = pickle.loads(data)
trajectory_data = frame.trajectories # TODO: properly refractor # trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
# trajectory_data = json.loads(data) # trajectory_data = json.loads(data)
logger.debug(f"Receive {trajectory_data}") logger.debug(f"Receive {frame.index}")
# class FakeNode: # class FakeNode:
# def __init__(self, node_type: NodeType): # def __init__(self, node_type: NodeType):
# self.type = node_type # self.type = node_type
input_dict = {} input_dict = {}
for identifier, trajectory in trajectory_data.items(): for identifier, track in frame.tracks.items():
# if len(trajectory['history']) < 7: # if len(trajectory['history']) < 7:
# # TODO: these trajectories should still be in the output, but without predictions # # TODO: these trajectories should still be in the output, but without predictions
# continue # continue
# TODO: modify this into a mapping function between JS data an the expected Node format # TODO: modify this into a mapping function between JS data an the expected Node format
# node = FakeNode(online_env.NodeType.PEDESTRIAN) # node = FakeNode(online_env.NodeType.PEDESTRIAN)
history = [[h['x'], h['y']] for h in trajectory['history']] history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
history = np.array(history) history = np.array(history)
x = history[:, 0] x = history[:, 0]
y = history[:, 1] y = history[:, 1]
# TODO: calculate dt based on input # TODO: calculate dt based on input
vx = derivative_of(x, 0.2) #eval_scene.dt vx = derivative_of(x, 0.1) #eval_scene.dt
vy = derivative_of(y, 0.2) vy = derivative_of(y, 0.1)
ax = derivative_of(vx, 0.2) ax = derivative_of(vx, 0.1)
ay = derivative_of(vy, 0.2) ay = derivative_of(vy, 0.1)
data_dict = {('position', 'x'): x[:], data_dict = {('position', 'x'): x[:],
('position', 'y'): y[:], ('position', 'y'): y[:],
@ -296,7 +302,7 @@ class PredictionServer:
# And want to update the network # And want to update the network
data = json.dumps({}) data = json.dumps({})
self.prediction_socket.send_string(data) self.prediction_socket.send_pyobj(frame)
continue continue
@ -320,7 +326,7 @@ class PredictionServer:
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=25, # TODO: make variable prediction_horizon=25, # TODO: make variable
num_samples=20, # TODO: make variable num_samples=5, # TODO: make variable
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)
end = time.time() end = time.time()
@ -338,6 +344,7 @@ class PredictionServer:
# prediction_dict provides the actual predictions # prediction_dict provides the actual predictions
# histories_dict provides the trajectory used for prediction # histories_dict provides the trajectory used for prediction
# futures_dict is the Ground Truth, which is unvailable in an online setting # futures_dict is the Ground Truth, which is unvailable in an online setting
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds}, prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
eval_scene.dt, eval_scene.dt,
hyperparams['maximum_history_length'], hyperparams['maximum_history_length'],
@ -358,24 +365,29 @@ class PredictionServer:
for node in histories_dict: for node in histories_dict:
history = histories_dict[node] history = histories_dict[node]
# future = futures_dict[node] # future = futures_dict[node] # ground truth dict
predictions = prediction_dict[node] predictions = prediction_dict[node]
if not len(history) or np.isnan(history[-1]).any(): if not len(history) or np.isnan(history[-1]).any():
continue continue
response[node.id] = { # response[node.id] = {
'id': node.id, # 'id': node.id,
'history': history.tolist(), # 'det_conf': trajectory_data[node.id]['det_conf'],
'predictions': predictions[0].tolist() # use batch 0 # 'bbox': trajectory_data[node.id]['bbox'],
} # 'history': history.tolist(),
# 'predictions': predictions[0].tolist() # use batch 0
# }
data = json.dumps(response) frame.tracks[node.id].predictor_history = history.tolist()
frame.tracks[node.id].predictions = predictions[0].tolist() # use batch 0
# data = json.dumps(response)
if self.config.predict_training_data: if self.config.predict_training_data:
logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s") logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
else: else:
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)") logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
self.prediction_socket.send_string(data) self.prediction_socket.send_pyobj(frame)
logger.info('Stopping') logger.info('Stopping')

View file

@ -1,4 +1,4 @@
import ffmpeg
from argparse import Namespace from argparse import Namespace
import datetime import datetime
import logging import logging
@ -33,27 +33,65 @@ class Renderer:
self.inv_H = np.linalg.pinv(self.H) self.inv_H = np.linalg.pinv(self.H)
# TODO: get FPS from frame_emitter
# self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
self.fps = 10
self.frame_size = (1280,720)
self.out_writer = self.start_writer() if self.config.render_file else None
self.streaming_process = self.start_streaming() if self.config.render_url else None
def start_writer(self):
if not self.config.output_dir.exists(): if not self.config.output_dir.exists():
raise FileNotFoundError("Path does not exist") raise FileNotFoundError("Path does not exist")
date_str = datetime.datetime.now().isoformat(timespec="minutes") date_str = datetime.datetime.now().isoformat(timespec="minutes")
filename = self.config.output_dir / f"render_predictions-{date_str}.mp4" filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
logger.info(f"Write to {filename}") logger.info(f"Write to {filename}")
fourcc = cv2.VideoWriter_fourcc(*'vp09') fourcc = cv2.VideoWriter_fourcc(*'vp09')
# TODO: get FPS from frame_emitter
self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720)) return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
def start_streaming(self):
return (
ffmpeg
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
.output(
self.config.render_url,
#codec = "copy", # use same codecs of the original video
codec='libx264',
listen=1, # enables HTTP server
pix_fmt="yuv420p",
preset="ultrafast",
tune="zerolatency",
g=f"{self.fps*2}",
analyzeduration="2000000",
probesize="1000000",
f='mpegts'
)
.overwrite_output()
.run_async(pipe_stdin=True)
)
# return process
def run(self): def run(self):
predictions = {} prediction_frame = None
i=0 i=0
first_time = None first_time = None
while self.is_running.is_set(): while self.is_running.is_set():
i+=1 i+=1
zmq_ev = self.frame_sock.poll(timeout=3)
if not zmq_ev:
# when no data comes in, loop so that is_running is checked
continue
frame: Frame = self.frame_sock.recv_pyobj() frame: Frame = self.frame_sock.recv_pyobj()
try: try:
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK) prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError as e: except zmq.ZMQError as e:
logger.debug(f'reuse prediction') logger.debug(f'reuse prediction')
@ -68,47 +106,77 @@ class Renderer:
# warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000)) # warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
# cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame) # cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
if not prediction_frame:
cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
continue
else:
for track_id, track in prediction_frame.tracks.items():
if not len(track.history):
continue
for track_id, prediction in predictions.items(): # coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
if not 'history' in prediction or not len(prediction['history']): coords = [d.get_foot_coords() for d in track.history]
continue # logger.warning(f"{coords=}")
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0] for ci in range(1, len(coords)):
# logger.warning(f"{coords=}") start = [int(p) for p in coords[ci-1]]
center = [int(p) for p in coords[-1]] end = [int(p) for p in coords[ci]]
cv2.circle(img, center, 5, (0,255,0)) cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
cv2.putText(img, track_id, (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(0,255,0))
for ci in range(1, len(coords)): if not track.predictions or not len(track.predictions):
start = [int(p) for p in coords[ci-1]] continue
end = [int(p) for p in coords[ci]]
cv2.line(img, start, end, (255,255,255), 2)
if not 'predictions' in prediction or not len(prediction['predictions']): for pred_i, pred in enumerate(track.predictions):
continue pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
color = (0,0,255) if pred_i == 1 else (100,100,100)
for pred in prediction['predictions']: for ci in range(1, len(pred_coords)):
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0] start = [int(p) for p in pred_coords[ci-1]]
for ci in range(1, len(pred_coords)): end = [int(p) for p in pred_coords[ci]]
start = [int(p) for p in pred_coords[ci-1]] cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
end = [int(p) for p in pred_coords[ci]]
cv2.line(img, start, end, (0,0,255), 1)
for track_id, track in prediction_frame.tracks.items():
# draw tracker marker and track id last so it lies over the trajectories
# this goes is a second loop so it overlays over _all_ trajectories
# coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
coords = track.history[-1].get_foot_coords()
center = [int(p) for p in coords]
cv2.circle(img, center, 5, (0,255,0))
(l, t, r, b) = track.history[-1].to_ltrb()
p1 = (l, t)
p2 = (r, b)
cv2.rectangle(img, p1, p2, (255,0,0), 1)
cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
if first_time is None: if first_time is None:
first_time = frame.time first_time = frame.time
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1) cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
if prediction_frame:
cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve() img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
# cv2.imwrite(str(img_path), img) # cv2.imwrite(str(img_path), img)
self.out.write(img) logger.info(f"write frame {frame.time - first_time:.3f}s")
if self.out_writer:
self.out_writer.write(img)
if self.streaming_process:
self.streaming_process.stdin.write(img.tobytes())
logger.info('Stopping') logger.info('Stopping')
if i>2: if i>2:
self.out.release() if self.streaming_process:
self.streaming_process.stdin.close()
if self.out_writer:
self.out_writer.release()
if self.streaming_process:
# oddly wrapped, because both close and release() take time.
self.streaming_process.wait()

View file

@ -1,8 +1,10 @@
from argparse import Namespace from argparse import Namespace
import asyncio import asyncio
import errno
import logging import logging
from multiprocessing import Event from multiprocessing import Event
import subprocess
from typing import Set, Union, Dict, Any from typing import Set, Union, Dict, Any
from typing_extensions import Self from typing_extensions import Self
@ -145,7 +147,14 @@ class WsRouter:
# loop = tornado.ioloop.IOLoop.current() # loop = tornado.ioloop.IOLoop.current()
logger.info(f"Listen on {self.config.ws_port}") logger.info(f"Listen on {self.config.ws_port}")
self.application.listen(self.config.ws_port) try:
self.application.listen(self.config.ws_port)
except OSError as e:
if e.errno == errno.EADDRINUSE:
logger.critical("Address already in use by process")
subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
raise
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
task = self.evt_loop.create_task(self.prediction_forwarder()) task = self.evt_loop.create_task(self.prediction_forwarder())
@ -156,9 +165,12 @@ class WsRouter:
async def prediction_forwarder(self): async def prediction_forwarder(self):
logger.info("Starting prediction forwarder") logger.info("Starting prediction forwarder")
while self.is_running.is_set(): while self.is_running.is_set():
msg = await self.prediction_socket.recv_string() # timeout so that if no events occur, loop can still stop on is_running
logger.debug(f"Forward prediction message of {len(msg)} chars") has_event = await self.prediction_socket.poll(timeout=1)
WebSocketPredictionHandler.write_to_clients(msg) if has_event:
msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
# die together: # die together:
self.evt_loop.stop() self.evt_loop.stop()

View file

@ -8,22 +8,24 @@ from multiprocessing import Event
from pathlib import Path from pathlib import Path
import pickle import pickle
import time import time
from typing import Optional
import numpy as np import numpy as np
import torch import torch
import zmq import zmq
import cv2 import cv2
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from deep_sort_realtime.deepsort_tracker import DeepSort from deep_sort_realtime.deepsort_tracker import DeepSort
from torchvision.models import ResNet50_Weights
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult from ultralytics.engine.results import Results as YOLOResult
from trap.frame_emitter import Frame from trap.frame_emitter import Frame, Detection, Track
Detection = [int, int, int, int, float, int] # Detection = [int, int, int, int, float, int]
Detections = [Detection] # Detections = [Detection]
# This is the dt that is also used by the scene. # This is the dt that is also used by the scene.
# as this needs to be rather stable, try to adhere # as this needs to be rather stable, try to adhere
@ -33,30 +35,13 @@ TARGET_DT = .1
logger = logging.getLogger("trap.tracker") logger = logging.getLogger("trap.tracker")
DETECTOR_RESNET = 'resnet' DETECTOR_RETINANET = 'retinanet'
DETECTOR_MASKRCNN = 'maskrcnn'
DETECTOR_FASTERRCNN = 'fasterrcnn'
DETECTOR_YOLOv8 = 'ultralytics' DETECTOR_YOLOv8 = 'ultralytics'
DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8] DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
@dataclass
class Track:
track_id: str = None
history: [Detection]= field(default_factory=lambda: [])
@dataclass
class Detection:
track_id: str
l: int # left
t: int # top
w: int # width
h: int # height
def get_foot_coords(self):
return [self.l + 0.5 * self.w, self.t+self.h]
@classmethod
def from_deepsort(cls, dstrack: DeepsortTrack):
return cls(dstrack.track_id, *dstrack.to_ltwh())
@ -81,27 +66,40 @@ class Tracker:
# TODO: support removal # TODO: support removal
self.tracks = defaultdict(lambda: Track()) self.tracks = defaultdict(lambda: Track())
if self.config.detector == DETECTOR_RESNET: if self.config.detector == DETECTOR_RETINANET:
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2) # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20) self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
self.model.to(self.device) self.model.to(self.device)
# Put the model in inference mode # Put the model in inference mode
self.model.eval() self.model.eval()
# Get the transforms for the model's weights # Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device) self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
)
elif self.config.detector == DETECTOR_MASKRCNN:
weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
self.model.to(self.device)
# Put the model in inference mode
self.model.eval()
# Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
)
elif self.config.detector == DETECTOR_YOLOv8: elif self.config.detector == DETECTOR_YOLOv8:
self.model = YOLO('EXPERIMENTS/yolov8x.pt') self.model = YOLO('EXPERIMENTS/yolov8x.pt')
else: else:
raise RuntimeError("No valid detector specified. See --help") raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
# homography = list(source.glob('*img2world.txt'))[0] # homography = list(source.glob('*img2world.txt'))[0]
self.H = np.loadtxt(self.config.homography, delimiter=',') self.H = np.loadtxt(self.config.homography, delimiter=',')
self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
logger.debug("Set up tracker") logger.debug("Set up tracker")
@ -124,17 +122,29 @@ class Tracker:
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE) training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
frame_i = 0 prev_frame_i = -1
while self.is_running.is_set():
this_run_time = time.time() while self.is_running.is_set():
# logger.debug(f'test {prev_run_time - this_run_time}') # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT)) # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
prev_run_time = time.time() # so for now, timing should move to emitter
# this_run_time = time.time()
# # logger.debug(f'test {prev_run_time - this_run_time}')
# time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
# prev_run_time = time.time()
msg = self.frame_sock.recv()
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
start_time = time.time() start_time = time.time()
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
if frame.index > (prev_frame_i+1):
logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
prev_frame_i = frame.index
# load homography into frame (TODO: should this be done in emitter?)
frame.H = self.H
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
if self.config.detector == DETECTOR_YOLOv8: if self.config.detector == DETECTOR_YOLOv8:
@ -144,37 +154,43 @@ class Tracker:
# Store detections into tracklets # Store detections into tracklets
projected_coordinates = []
for detection in detections: for detection in detections:
track = self.tracks[detection.track_id] track = self.tracks[detection.track_id]
track.track_id = detection.track_id # for new tracks track.track_id = detection.track_id # for new tracks
track.history.append(detection) track.history.append(detection) # add to history
# projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
# TODO: hadle occlusions, and dissappearance
# if len(track.history) > 30: # retain 90 tracks for 90 frames # if len(track.history) > 30: # retain 90 tracks for 90 frames
# track.history.pop(0) # track.history.pop(0)
foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
if len(foot_coordinates):
projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
else:
projected_coordinates = []
# print(TEMP_proj_coords) # trajectories = {}
trajectories = {} # for detection in detections:
for detection, coords in zip(detections, projected_coordinates): # tid = str(detection.track_id)
tid = str(detection.track_id) # track = self.tracks[detection.track_id]
trajectories[tid] = { # coords = track.get_projected_history(self.H) # get full history
"id": tid, # trajectories[tid] = {
"history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test # "id": tid,
} # "det_conf": detection.conf,
# "bbox": detection.to_ltwh(),
# "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
# }
active_track_ids = [d.track_id for d in detections]
active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
# logger.info(f"{trajectories}") # logger.info(f"{trajectories}")
frame.trajectories = trajectories frame.tracks = active_tracks
# if self.config.bypass_prediction:
# self.trajectory_socket.send_string(json.dumps(trajectories))
# else:
# self.trajectory_socket.send(pickle.dumps(frame))
self.trajectory_socket.send_pyobj(frame)
current_time = time.time() current_time = time.time()
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)") logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
if self.config.bypass_prediction:
self.trajectory_socket.send_string(json.dumps(trajectories))
else:
self.trajectory_socket.send(pickle.dumps(frame))
# self.trajectory_socket.send_string(json.dumps(trajectories)) # self.trajectory_socket.send_string(json.dumps(trajectories))
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}} # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
@ -184,13 +200,14 @@ class Tracker:
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
if training_csv: if training_csv:
training_csv.writerows([{ training_csv.writerows([{
'frame_id': round(frame_i * 10., 1), # not really time 'frame_id': round(frame.index * 10., 1), # not really time
'track_id': t['id'], 'track_id': t['id'],
'x': t['history'][-1]['x'], 'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
'y': t['history'][-1]['y'], 'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
} for t in trajectories.values()]) } for t in active_tracks.values()])
training_frames += len(trajectories) training_frames += len(active_tracks)
frame_i += 1 # print(time.time() - start_time)
if training_fp: if training_fp:
training_fp.close() training_fp.close()
@ -214,6 +231,9 @@ class Tracker:
def _yolov8_track(self, img) -> [Detection]: def _yolov8_track(self, img) -> [Detection]:
results: [YOLOResult] = self.model.track(img, persist=True) results: [YOLOResult] = self.model.track(img, persist=True)
if results[0].boxes is None or results[0].boxes.id is None:
# work around https://github.com/ultralytics/ultralytics/issues/5968
return []
return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())] return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
def _resnet_track(self, img) -> [Detection]: def _resnet_track(self, img) -> [Detection]:
@ -221,7 +241,7 @@ class Tracker:
tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img) tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
return [Detection.from_deepsort(t) for t in tracks] return [Detection.from_deepsort(t) for t in tracks]
def _resnet_detect_persons(self, frame) -> Detections: def _resnet_detect_persons(self, frame) -> [Detection]:
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C) # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
t = t.permute(2, 0, 1) t = t.permute(2, 0, 1)