Compare commits
10 commits: 185962ace5 ... cd8a06a53b
SHA1: cd8a06a53b, 44a618a5ee, a3e42b4501, f3b8e031c1, ec9bb357fd, 104098d371, 1aed4161f6, fd2e8a3b49, 3091557733, b911caa5af
10 changed files with 1392 additions and 1253 deletions
poetry.lock (generated, 1813 changes)
File diff suppressed because it is too large.
@@ -26,6 +26,9 @@ torchvision = [
]
deep-sort-realtime = "^1.3.2"
ultralytics = "^8.0.200"
ffmpeg-python = "^0.2.0"
torchreid = "^0.2.5"
gdown = "^4.7.1"

[build-system]
requires = ["poetry-core"]
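The hunk above (apparently pyproject.toml) adds the tracking, detection and streaming dependencies used throughout this compare. A minimal sanity-check sketch, assuming the import names match the package names listed there:

```python
# Hedged sketch: verify the newly added dependencies resolve after `poetry install`.
# The module names below are assumptions based on the package names in the hunk.
import deep_sort_realtime   # deep-sort-realtime: multi-object tracking
import ultralytics          # ultralytics: YOLOv8 detector
import ffmpeg               # ffmpeg-python: streaming pipe used by the renderer
import torchreid            # torchreid: optional re-identification embedder
import gdown                # gdown: downloading model weights

print("tracking / detection / streaming dependencies import cleanly")
```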
test_homography.ipynb (new file, 383 additions)
File diff suppressed because one or more lines are too long.
@@ -194,8 +194,8 @@ frame_emitter_parser.add_argument("--video-src",
default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
#TODO: camera as source

frame_emitter_parser.add_argument("--video-no-loop",
help="By default the emitter will run indefinitely. This prevents that and plays every video only once.",
frame_emitter_parser.add_argument("--video-loop",
help="By default the emitter will run only once. This allows it to loop the video file to keep testing.",
action='store_true')
#TODO: camera as source

@@ -217,7 +217,17 @@ tracker_parser.add_argument("--detector",

# Renderer

render_parser.add_argument("--render-preview",
render_parser.add_argument("--render-file",
help="Render a video file previewing the prediction, and its delay compared to the current frame",
action='store_true')

render_parser.add_argument("--render-url",
help="""Stream renderer on given URL. Two easy approaches:
- using the zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
When using UDP, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
""",
type=str,
default=None)
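A small sketch of the flag flip above: the removed `--video-no-loop` meant "loop by default, pass the flag to play once", while the new `--video-loop` means "play once by default, pass the flag to loop". The group name is taken from the hunk; the bare parser is illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
frame_emitter_parser = parser.add_argument_group('frame emitter')
frame_emitter_parser.add_argument(
    "--video-loop",
    help="By default the emitter will run only once. This allows it to loop the video file to keep testing.",
    action='store_true')

args = parser.parse_args([])                  # default: every video plays once
assert args.video_loop is False
args = parser.parse_args(["--video-loop"])    # opt in to looping for testing
assert args.video_loop is True
```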
@@ -3,21 +3,73 @@ from dataclasses import dataclass, field
from itertools import cycle
import logging
from multiprocessing import Event
from pathlib import Path
import pickle
import sys
import time
from typing import Optional
from typing import Iterable, Optional
import numpy as np
import cv2
import zmq
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack

logger = logging.getLogger('trap.frame_emitter')


@dataclass
class Detection:
    track_id: str  # deepsort track id association
    l: int  # left - image space
    t: int  # top - image space
    w: int  # width - image space
    h: int  # height - image space
    conf: float  # object detector probability

    def get_foot_coords(self):
        return [self.l + 0.5 * self.w, self.t + self.h]

    @classmethod
    def from_deepsort(cls, dstrack: DeepsortTrack):
        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)

    def to_ltwh(self):
        return (int(self.l), int(self.t), int(self.w), int(self.h))

    def to_ltrb(self):
        return (int(self.l), int(self.t), int(self.l + self.w), int(self.t + self.h))


@dataclass
class Track:
    """A bit of a haphazard wrapper around the 'real' tracker to provide
    a history with which the predictor can work, as we can then deduce velocity
    and acceleration.
    """
    track_id: str = None
    history: [Detection] = field(default_factory=lambda: [])
    predictor_history: Optional[list] = None  # in image space
    predictions: Optional[list] = None

    def get_projected_history(self, H) -> np.array:
        foot_coordinates = [d.get_foot_coords() for d in self.history]

        if len(foot_coordinates):
            coords = cv2.perspectiveTransform(np.array([foot_coordinates]), H)
            return coords[0]
        return np.array([])

    def get_projected_history_as_dict(self, H) -> dict:
        coords = self.get_projected_history(H)
        return [{"x": c[0], "y": c[1]} for c in coords]


@dataclass
class Frame:
    index: int
    img: np.array
    time: float = field(default_factory=lambda: time.time())
    trajectories: Optional[dict] = None
    tracks: Optional[dict[str, Track]] = None
    H: Optional[np.array] = None


class FrameEmitter:
    '''
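`Track.get_projected_history` above projects the foot point of every stored detection (bottom-centre of the bounding box) from image space to world space with a homography H. A minimal sketch of that projection; the homography values and box are made up:

```python
import numpy as np
import cv2

# Hypothetical 3x3 image-to-world homography; in trap it is loaded from the
# file passed via --homography.
H = np.array([[0.05, 0.00, -10.0],
              [0.00, 0.05,  -5.0],
              [0.00, 0.00,   1.0]])

l, t, w, h = 320, 200, 80, 160                # made-up bounding box (image space)
foot = [l + 0.5 * w, t + h]                   # same formula as Detection.get_foot_coords()

pts = np.array([[foot]], dtype=np.float64)    # cv2.perspectiveTransform wants (1, N, 2)
world_xy = cv2.perspectiveTransform(pts, H)[0][0]
print(world_xy)                               # projected (x, y) used by the predictor
```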
@@ -36,10 +88,10 @@ class FrameEmitter:
        logger.info(f"Connection socket {config.zmq_frame_addr}")

        if self.config.video_no_loop:
            self.video_srcs = self.config.video_src
        if self.config.video_loop:
            self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
        else:
            self.video_srcs = cycle(self.config.video_src)
            self.video_srcs: [Path] = self.config.video_src


    def emit_video(self):
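The `--video-loop` branch above wraps the source list in `itertools.cycle`, so iteration over the files never ends; without the flag the plain list is exhausted once. A tiny illustration with made-up paths:

```python
from itertools import cycle
from pathlib import Path

video_src = [Path("a.mp4"), Path("b.mp4")]    # hypothetical sources

play_once = iter(video_src)                   # stops after b.mp4
looped = cycle(video_src)                     # a.mp4, b.mp4, a.mp4, ... forever

print(next(looped), next(looped), next(looped))
```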
@@ -51,6 +103,7 @@ class FrameEmitter:
            logger.info(f"Emit frames at {fps} fps")

            prev_time = time.time()
            i = 0
            while self.is_running.is_set():
                ret, img = video.read()

@@ -61,7 +114,13 @@ class FrameEmitter:
                # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
                # ret, img = video.read()
                # assert ret is not False # not really error proof...
                frame = Frame(img=img)

                if "DATASETS/hof/" in str(video_path):
                    # hack to mask out area
                    cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)

                frame = Frame(index=i, img=img)
                # TODO: this is very dirty, need to find another way.
                # perhaps multiprocessing Array?
                self.frame_sock.send(pickle.dumps(frame))
@@ -74,11 +133,14 @@ class FrameEmitter:
                    new_frame_time += frame_duration - time_diff
                else:
                    prev_time = new_frame_time

                i += 1

                if not self.is_running.is_set():
                    # if not running, also break out of infinite generator loop
                    break

        logger.info("Stopping")
@@ -73,7 +73,7 @@ def start():
        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
    ]

    if args.render_preview:
    if args.render_file or args.render_url:
        procs.append(
            ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
        )
@@ -27,6 +27,7 @@ import matplotlib.pyplot as plt
import zmq

from trap.frame_emitter import Frame
from trap.tracker import Track

logger = logging.getLogger("trap.prediction")

@@ -242,33 +243,38 @@ class PredictionServer:
            if self.config.predict_training_data:
                input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
            else:
                zmq_ev = self.trajectory_socket.poll(timeout=3)
                if not zmq_ev:
                    # on no data loop so that is_running is checked
                    continue

                data = self.trajectory_socket.recv()
                frame: Frame = pickle.loads(data)
                trajectory_data = frame.trajectories  # TODO: properly refactor
                # trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
                # trajectory_data = json.loads(data)
                logger.debug(f"Receive {trajectory_data}")
                logger.debug(f"Receive {frame.index}")

                # class FakeNode:
                #     def __init__(self, node_type: NodeType):
                #         self.type = node_type

                input_dict = {}
                for identifier, trajectory in trajectory_data.items():
                for identifier, track in frame.tracks.items():
                    # if len(trajectory['history']) < 7:
                    #     # TODO: these trajectories should still be in the output, but without predictions
                    #     continue

                    # TODO: modify this into a mapping function between JS data and the expected Node format
                    # node = FakeNode(online_env.NodeType.PEDESTRIAN)
                    history = [[h['x'], h['y']] for h in trajectory['history']]
                    history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
                    history = np.array(history)
                    x = history[:, 0]
                    y = history[:, 1]
                    # TODO: calculate dt based on input
                    vx = derivative_of(x, 0.2)  # eval_scene.dt
                    vy = derivative_of(y, 0.2)
                    ax = derivative_of(vx, 0.2)
                    ay = derivative_of(vy, 0.2)
                    vx = derivative_of(x, 0.1)  # eval_scene.dt
                    vy = derivative_of(y, 0.1)
                    ax = derivative_of(vx, 0.1)
                    ay = derivative_of(vy, 0.1)

                    data_dict = {('position', 'x'): x[:],
                                 ('position', 'y'): y[:],
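The hunk above changes the assumed sample spacing for velocity and acceleration from 0.2s to 0.1s. Assuming `derivative_of` behaves like Trajectron++'s finite-difference helper (roughly `np.gradient` with a fixed dt), a sketch of what that change does to the derived velocities:

```python
import numpy as np

def derivative_of(values: np.ndarray, dt: float) -> np.ndarray:
    # Assumed behaviour: central finite differences with spacing dt.
    return np.gradient(values, dt)

x = np.array([0.0, 0.1, 0.3, 0.6])    # made-up world-space positions
vx_at_02 = derivative_of(x, 0.2)      # samples assumed 0.2s apart
vx_at_01 = derivative_of(x, 0.1)      # samples assumed 0.1s apart: velocities double
print(vx_at_02, vx_at_01)
```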
@@ -296,7 +302,7 @@ class PredictionServer:
                    # And want to update the network

                    data = json.dumps({})
                    self.prediction_socket.send_string(data)
                    self.prediction_socket.send_pyobj(frame)

                    continue

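Several hunks in this compare swap JSON strings for pickled Python objects on the ZMQ sockets (`send_string`/`send_pyobj`), so whole `Frame` objects travel between processes. A minimal round-trip of that pattern; socket types, address and payload are illustrative only:

```python
import zmq

ctx = zmq.Context.instance()

sender = ctx.socket(zmq.PUSH)
sender.bind("inproc://frames-example")        # hypothetical address

receiver = ctx.socket(zmq.PULL)
receiver.connect("inproc://frames-example")

payload = {"index": 42, "tracks": {}}         # stand-in for a Frame dataclass
sender.send_pyobj(payload)                    # pickles the object
print(receiver.recv_pyobj())                  # unpickles it on the other side
```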
@@ -320,7 +326,7 @@ class PredictionServer:
                dists, preds = trajectron.incremental_forward(input_dict,
                                                              maps,
                                                              prediction_horizon=25,  # TODO: make variable
                                                              num_samples=20,  # TODO: make variable
                                                              num_samples=5,  # TODO: make variable
                                                              robot_present_and_future=robot_present_and_future,
                                                              full_dist=True)
                end = time.time()
@@ -338,6 +344,7 @@ class PredictionServer:
                # prediction_dict provides the actual predictions
                # histories_dict provides the trajectory used for prediction
                # futures_dict is the Ground Truth, which is unavailable in an online setting

                prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
                                                                                                   eval_scene.dt,
                                                                                                   hyperparams['maximum_history_length'],
@@ -358,24 +365,29 @@ class PredictionServer:

                for node in histories_dict:
                    history = histories_dict[node]
                    # future = futures_dict[node]
                    # future = futures_dict[node]  # ground truth dict
                    predictions = prediction_dict[node]

                    if not len(history) or np.isnan(history[-1]).any():
                        continue

                    response[node.id] = {
                        'id': node.id,
                        'history': history.tolist(),
                        'predictions': predictions[0].tolist()  # use batch 0
                    }
                    # response[node.id] = {
                    #     'id': node.id,
                    #     'det_conf': trajectory_data[node.id]['det_conf'],
                    #     'bbox': trajectory_data[node.id]['bbox'],
                    #     'history': history.tolist(),
                    #     'predictions': predictions[0].tolist()  # use batch 0
                    # }

                data = json.dumps(response)
                    frame.tracks[node.id].predictor_history = history.tolist()
                    frame.tracks[node.id].predictions = predictions[0].tolist()  # use batch 0

                # data = json.dumps(response)
                if self.config.predict_training_data:
                    logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
                else:
                    logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
                self.prediction_socket.send_string(data)
                self.prediction_socket.send_pyobj(frame)
        logger.info('Stopping')
trap/renderer.py (134 changes)
@@ -1,4 +1,4 @@

import ffmpeg
from argparse import Namespace
import datetime
import logging
@@ -33,27 +33,65 @@ class Renderer:

        self.inv_H = np.linalg.pinv(self.H)

        # TODO: get FPS from frame_emitter
        # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
        self.fps = 10
        self.frame_size = (1280,720)
        self.out_writer = self.start_writer() if self.config.render_file else None
        self.streaming_process = self.start_streaming() if self.config.render_url else None

    def start_writer(self):
        if not self.config.output_dir.exists():
            raise FileNotFoundError("Path does not exist")

        date_str = datetime.datetime.now().isoformat(timespec="minutes")
        filename = self.config.output_dir / f"render_predictions-{date_str}.mp4"
        filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
        logger.info(f"Write to {filename}")

        fourcc = cv2.VideoWriter_fourcc(*'vp09')
        # TODO: get FPS from frame_emitter
        self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))

        return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)

    def start_streaming(self):
        return (
            ffmpeg
            .input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
            .output(
                self.config.render_url,
                #codec = "copy", # use same codecs of the original video
                codec='libx264',
                listen=1, # enables HTTP server
                pix_fmt="yuv420p",
                preset="ultrafast",
                tune="zerolatency",
                g=f"{self.fps*2}",
                analyzeduration="2000000",
                probesize="1000000",
                f='mpegts'
            )
            .overwrite_output()
            .run_async(pipe_stdin=True)
        )
        # return process


    def run(self):
        predictions = {}
        prediction_frame = None
        i=0
        first_time = None
        while self.is_running.is_set():
            i+=1

            zmq_ev = self.frame_sock.poll(timeout=3)
            if not zmq_ev:
                # when no data comes in, loop so that is_running is checked
                continue

            frame: Frame = self.frame_sock.recv_pyobj()
            try:
                predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
                prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                logger.debug(f'reuse prediction')
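The streaming pipe above feeds raw `bgr24` frames into ffmpeg, so every write to its `stdin` must be exactly width × height × 3 bytes, which is what OpenCV's default BGR `uint8` images give. A small sketch of that invariant plus how a client would connect, using the values from the hunk:

```python
import numpy as np

frame_size = (1280, 720)                       # (width, height), as set in __init__
img = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)   # OpenCV frame: (h, w, 3) BGR

assert len(img.tobytes()) == frame_size[0] * frame_size[1] * 3 == 2_764_800

# streaming_process.stdin.write(img.tobytes())   # what Renderer.run() does per frame
# Viewing the stream (from the --render-url help text):
#   ffplay zmq:tcp://<host>:5556           # zmq wrapper, renderer listens
#   ffplay udp://<client-ip>:5556/stream   # udp, renderer pushes to the client
```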
@@ -68,47 +106,77 @@ class Renderer:
            # warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
            # cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)


            for track_id, prediction in predictions.items():
                if not 'history' in prediction or not len(prediction['history']):
                    continue

                coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
                # logger.warning(f"{coords=}")
                center = [int(p) for p in coords[-1]]
                cv2.circle(img, center, 5, (0,255,0))
                cv2.putText(img, track_id, (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(0,255,0))
            if not prediction_frame:
                cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
                continue
            else:
                for track_id, track in prediction_frame.tracks.items():
                    if not len(track.history):
                        continue

                    # coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
                    coords = [d.get_foot_coords() for d in track.history]
                    # logger.warning(f"{coords=}")

                    for ci in range(1, len(coords)):
                        start = [int(p) for p in coords[ci-1]]
                        end = [int(p) for p in coords[ci]]
                        cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)

                    for ci in range(1, len(coords)):
                        start = [int(p) for p in coords[ci-1]]
                        end = [int(p) for p in coords[ci]]
                        cv2.line(img, start, end, (255,255,255), 2)

                    if not 'predictions' in prediction or not len(prediction['predictions']):
                        continue

                    for pred in prediction['predictions']:
                        pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
                        for ci in range(1, len(pred_coords)):
                            start = [int(p) for p in pred_coords[ci-1]]
                            end = [int(p) for p in pred_coords[ci]]
                            cv2.line(img, start, end, (0,0,255), 1)
                    if not track.predictions or not len(track.predictions):
                        continue

                    for pred_i, pred in enumerate(track.predictions):
                        pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
                        color = (0,0,255) if pred_i == 1 else (100,100,100)
                        for ci in range(1, len(pred_coords)):
                            start = [int(p) for p in pred_coords[ci-1]]
                            end = [int(p) for p in pred_coords[ci]]
                            cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)

                for track_id, track in prediction_frame.tracks.items():
                    # draw tracker marker and track id last so it lies over the trajectories
                    # this goes in a second loop so it overlays over _all_ trajectories
                    # coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
                    coords = track.history[-1].get_foot_coords()

                    center = [int(p) for p in coords]
                    cv2.circle(img, center, 5, (0,255,0))
                    (l, t, r, b) = track.history[-1].to_ltrb()
                    p1 = (l, t)
                    p2 = (r, b)
                    cv2.rectangle(img, p1, p2, (255,0,0), 1)
                    cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)

            if first_time is None:
                first_time = frame.time

            cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
            cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
            cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)

            if prediction_frame:
                cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)


            img_path = (self.config.output_dir / f"{i:05d}.png").resolve()

            # cv2.imwrite(str(img_path), img)
            self.out.write(img)
            logger.info(f"write frame {frame.time - first_time:.3f}s")
            if self.out_writer:
                self.out_writer.write(img)
            if self.streaming_process:
                self.streaming_process.stdin.write(img.tobytes())
        logger.info('Stopping')

        if i>2:
            self.out.release()
        if self.streaming_process:
            self.streaming_process.stdin.close()
        if self.out_writer:
            self.out_writer.release()
        if self.streaming_process:
            # oddly wrapped, because both close and release() take time.
            self.streaming_process.wait()
@@ -1,8 +1,10 @@

from argparse import Namespace
import asyncio
import errno
import logging
from multiprocessing import Event
import subprocess
from typing import Set, Union, Dict, Any
from typing_extensions import Self

@@ -145,7 +147,14 @@ class WsRouter:

        # loop = tornado.ioloop.IOLoop.current()
        logger.info(f"Listen on {self.config.ws_port}")
        self.application.listen(self.config.ws_port)
        try:
            self.application.listen(self.config.ws_port)
        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                logger.critical("Address already in use by process")
                subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
            raise

        loop = asyncio.get_event_loop()

        task = self.evt_loop.create_task(self.prediction_forwarder())
@@ -156,9 +165,12 @@ class WsRouter:
    async def prediction_forwarder(self):
        logger.info("Starting prediction forwarder")
        while self.is_running.is_set():
            msg = await self.prediction_socket.recv_string()
            logger.debug(f"Forward prediction message of {len(msg)} chars")
            WebSocketPredictionHandler.write_to_clients(msg)
            # timeout so that if no events occur, loop can still stop on is_running
            has_event = await self.prediction_socket.poll(timeout=1)
            if has_event:
                msg = await self.prediction_socket.recv_string()
                logger.debug(f"Forward prediction message of {len(msg)} chars")
                WebSocketPredictionHandler.write_to_clients(msg)

        # die together:
        self.evt_loop.stop()
trap/tracker.py (152 changes)
@@ -8,22 +8,24 @@ from multiprocessing import Event
from pathlib import Path
import pickle
import time
from typing import Optional
import numpy as np
import torch
import zmq
import cv2

from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
from deep_sort_realtime.deepsort_tracker import DeepSort
from torchvision.models import ResNet50_Weights
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack

from ultralytics import YOLO
from ultralytics.engine.results import Results as YOLOResult

from trap.frame_emitter import Frame
from trap.frame_emitter import Frame, Detection, Track

Detection = [int, int, int, int, float, int]
Detections = [Detection]
# Detection = [int, int, int, int, float, int]
# Detections = [Detection]

# This is the dt that is also used by the scene.
# as this needs to be rather stable, try to adhere
@@ -33,32 +35,15 @@ TARGET_DT = .1

logger = logging.getLogger("trap.tracker")

DETECTOR_RESNET = 'resnet'
DETECTOR_RETINANET = 'retinanet'
DETECTOR_MASKRCNN = 'maskrcnn'
DETECTOR_FASTERRCNN = 'fasterrcnn'
DETECTOR_YOLOv8 = 'ultralytics'

DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]

@dataclass
class Track:
    track_id: str = None
    history: [Detection] = field(default_factory=lambda: [])

@dataclass
class Detection:
    track_id: str
    l: int  # left
    t: int  # top
    w: int  # width
    h: int  # height

    def get_foot_coords(self):
        return [self.l + 0.5 * self.w, self.t + self.h]

    @classmethod
    def from_deepsort(cls, dstrack: DeepsortTrack):
        return cls(dstrack.track_id, *dstrack.to_ltwh())



class Tracker:
    def __init__(self, config: Namespace, is_running: Event):
@@ -81,27 +66,40 @@ class Tracker:
        # TODO: support removal
        self.tracks = defaultdict(lambda: Track())

        if self.config.detector == DETECTOR_RESNET:
        if self.config.detector == DETECTOR_RETINANET:
            # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
            # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
            weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
            self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
            self.model.to(self.device)
            # Put the model in inference mode
            self.model.eval()
            # Get the transforms for the model's weights
            self.preprocess = weights.transforms().to(self.device)
            self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
                                        )
        elif self.config.detector == DETECTOR_MASKRCNN:
            weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
            self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
            self.model.to(self.device)
            # Put the model in inference mode
            self.model.eval()
            # Get the transforms for the model's weights
            self.preprocess = weights.transforms().to(self.device)
            self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
                                        # embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
                                        )
        elif self.config.detector == DETECTOR_YOLOv8:
            self.model = YOLO('EXPERIMENTS/yolov8x.pt')
        else:
            raise RuntimeError("No valid detector specified. See --help")
            raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")


        # homography = list(source.glob('*img2world.txt'))[0]

        self.H = np.loadtxt(self.config.homography, delimiter=',')

        self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
        logger.debug("Set up tracker")
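Each torchvision detector above now gets its own DeepSort instance instead of the single `DeepSort(max_age=30, ...)` created further down. A hedged sketch of how `deep_sort_realtime` is typically fed; boxes and confidences are made up, and `n_init=5` should mean a track needs five consecutive matches before it is confirmed:

```python
import numpy as np
from deep_sort_realtime.deepsort_tracker import DeepSort

tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5,
                   max_age=12, nms_max_overlap=0.9)

img = np.zeros((720, 1280, 3), dtype=np.uint8)          # stand-in BGR frame

# deep_sort_realtime takes ([left, top, width, height], confidence, class) tuples.
raw_detections = [
    ([320, 200, 80, 160], 0.87, 'person'),
    ([500, 210, 70, 150], 0.64, 'person'),
]

tracks = tracker.update_tracks(raw_detections, frame=img)
for t in tracks:
    if not t.is_confirmed():
        continue
    print(t.track_id, t.to_ltwh(), t.det_conf)
```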
@@ -124,17 +122,29 @@ class Tracker:
        # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
        training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)

        frame_i = 0
        prev_frame_i = -1

        while self.is_running.is_set():
            this_run_time = time.time()
            # logger.debug(f'test {prev_run_time - this_run_time}')
            time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
            prev_run_time = time.time()
            # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
            # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
            # so for now, timing should move to emitter
            # this_run_time = time.time()
            # # logger.debug(f'test {prev_run_time - this_run_time}')
            # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
            # prev_run_time = time.time()

            msg = self.frame_sock.recv()
            frame: Frame = pickle.loads(msg)  # frame delivery in current setup: 0.012-0.03s
            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
            start_time = time.time()
            frame: Frame = self.frame_sock.recv_pyobj()  # frame delivery in current setup: 0.012-0.03s

            if frame.index > (prev_frame_i+1):
                logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")

            prev_frame_i = frame.index
            # load homography into frame (TODO: should this be done in emitter?)
            frame.H = self.H

            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")


            if self.config.detector == DETECTOR_YOLOv8:
@@ -144,38 +154,44 @@ class Tracker:


            # Store detections into tracklets
            projected_coordinates = []
            for detection in detections:
                track = self.tracks[detection.track_id]
                track.track_id = detection.track_id  # for new tracks

                track.history.append(detection)
                track.history.append(detection)  # add to history
                # projected_coordinates.append(track.get_projected_history(self.H)) # then get full history

                # TODO: handle occlusions and disappearance
                # if len(track.history) > 30: # retain 90 tracks for 90 frames
                #     track.history.pop(0)

            foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
            if len(foot_coordinates):
                projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
            else:
                projected_coordinates = []

            # print(TEMP_proj_coords)
            trajectories = {}
            for detection, coords in zip(detections, projected_coordinates):
                tid = str(detection.track_id)
                trajectories[tid] = {
                    "id": tid,
                    "history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist()  # already doubly nested, fine for test
                }

            # trajectories = {}
            # for detection in detections:
            #     tid = str(detection.track_id)
            #     track = self.tracks[detection.track_id]
            #     coords = track.get_projected_history(self.H) # get full history
            #     trajectories[tid] = {
            #         "id": tid,
            #         "det_conf": detection.conf,
            #         "bbox": detection.to_ltwh(),
            #         "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist()  # already doubly nested, fine for test
            #     }
            active_track_ids = [d.track_id for d in detections]
            active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
            # logger.info(f"{trajectories}")
            frame.trajectories = trajectories
            frame.tracks = active_tracks

            current_time = time.time()
            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
            if self.config.bypass_prediction:
                self.trajectory_socket.send_string(json.dumps(trajectories))
            else:
                self.trajectory_socket.send(pickle.dumps(frame))
            # if self.config.bypass_prediction:
            #     self.trajectory_socket.send_string(json.dumps(trajectories))
            # else:
            #     self.trajectory_socket.send(pickle.dumps(frame))
            self.trajectory_socket.send_pyobj(frame)

            current_time = time.time()
            logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")

            # self.trajectory_socket.send_string(json.dumps(trajectories))
            # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
            # TODO: provide a track object that actually keeps history (unlike tracker)
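The net effect of this hunk: the ad-hoc `frame.trajectories` dict (world-space points built in the tracker) gives way to `frame.tracks`, a dict of `Track` objects that keep their raw image-space detections and are projected where needed. A rough sketch of both shapes, using simplified stand-ins for the dataclasses defined in trap/frame_emitter.py above:

```python
from dataclasses import dataclass, field

@dataclass
class Detection:              # simplified stand-in
    track_id: str
    l: int
    t: int
    w: int
    h: int
    conf: float

@dataclass
class Track:                  # simplified stand-in
    track_id: str
    history: list = field(default_factory=list)

# Old shape: already projected to world space by the tracker.
trajectories = {"12": {"id": "12", "history": [{"x": 1.3, "y": 4.1}, {"x": 1.4, "y": 4.0}]}}

# New shape: raw detections per track; consumers project on demand, e.g.
# track.get_projected_history_as_dict(frame.H) in the prediction server.
tracks = {"12": Track("12", [Detection("12", 320, 200, 80, 160, 0.87)])}
```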
@@ -184,13 +200,14 @@ class Tracker:
            # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
            if training_csv:
                training_csv.writerows([{
                    'frame_id': round(frame_i * 10., 1),  # not really time
                    'frame_id': round(frame.index * 10., 1),  # not really time
                    'track_id': t['id'],
                    'x': t['history'][-1]['x'],
                    'y': t['history'][-1]['y'],
                } for t in trajectories.values()])
                training_frames += len(trajectories)
            frame_i += 1
                    'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
                    'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
                } for t in active_tracks.values()])
                training_frames += len(active_tracks)
            # print(time.time() - start_time)


        if training_fp:
            training_fp.close()
@@ -214,6 +231,9 @@ class Tracker:

    def _yolov8_track(self, img) -> [Detection]:
        results: [YOLOResult] = self.model.track(img, persist=True)
        if results[0].boxes is None or results[0].boxes.id is None:
            # work around https://github.com/ultralytics/ultralytics/issues/5968
            return []
        return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]

    def _resnet_track(self, img) -> [Detection]:

@@ -221,7 +241,7 @@ class Tracker:
        tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
        return [Detection.from_deepsort(t) for t in tracks]

    def _resnet_detect_persons(self, frame) -> Detections:
    def _resnet_detect_persons(self, frame) -> [Detection]:
        t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # change axes of the loaded image to be compatible with torch.io.read_image (which has C,W,H format instead of W,H,C)
        t = t.permute(2, 0, 1)