Compare commits
No commits in common. "cd8a06a53bc0da129a2fba26a1b56b93d7fa85fb" and "185962ace5c777e39205ae8ff603107f3d381e75" have entirely different histories.
cd8a06a53b
...
185962ace5
10 changed files with 1254 additions and 1393 deletions
1813
poetry.lock
generated
1813
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -26,9 +26,6 @@ torchvision = [
|
|||
]
|
||||
deep-sort-realtime = "^1.3.2"
|
||||
ultralytics = "^8.0.200"
|
||||
ffmpeg-python = "^0.2.0"
|
||||
torchreid = "^0.2.5"
|
||||
gdown = "^4.7.1"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -194,8 +194,8 @@ frame_emitter_parser.add_argument("--video-src",
|
|||
default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
|
||||
#TODO: camera as source
|
||||
|
||||
frame_emitter_parser.add_argument("--video-loop",
|
||||
help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
|
||||
frame_emitter_parser.add_argument("--video-no-loop",
|
||||
help="By default it emitter will run indefiniately. This prevents that and plays every video only once.",
|
||||
action='store_true')
|
||||
#TODO: camera as source
|
||||
|
||||
|
@ -217,17 +217,7 @@ tracker_parser.add_argument("--detector",
|
|||
|
||||
# Renderer
|
||||
|
||||
render_parser.add_argument("--render-file",
|
||||
render_parser.add_argument("--render-preview",
|
||||
help="Render a video file previewing the prediction, and its delay compared to the current frame",
|
||||
action='store_true')
|
||||
|
||||
render_parser.add_argument("--render-url",
|
||||
help="""Stream renderer on given URL. Two easy approaches:
|
||||
- using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
|
||||
- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
|
||||
Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
|
||||
When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
|
||||
""",
|
||||
type=str,
|
||||
default=None)
|
||||
|
||||
|
|
|
@ -3,73 +3,21 @@ from dataclasses import dataclass, field
|
|||
from itertools import cycle
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import cv2
|
||||
import zmq
|
||||
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||
|
||||
logger = logging.getLogger('trap.frame_emitter')
|
||||
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
track_id: str # deepsort track id association
|
||||
l: int # left - image space
|
||||
t: int # top - image space
|
||||
w: int # width - image space
|
||||
h: int # height - image space
|
||||
conf: float # object detector probablity
|
||||
|
||||
def get_foot_coords(self):
|
||||
return [self.l + 0.5 * self.w, self.t+self.h]
|
||||
|
||||
@classmethod
|
||||
def from_deepsort(cls, dstrack: DeepsortTrack):
|
||||
return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
|
||||
|
||||
def to_ltwh(self):
|
||||
return (int(self.l), int(self.t), int(self.w), int(self.h))
|
||||
|
||||
def to_ltrb(self):
|
||||
return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Track:
|
||||
"""A bit of an haphazardous wrapper around the 'real' tracker to provide
|
||||
a history, with which the predictor can work, as we then can deduce velocity
|
||||
and acceleration.
|
||||
"""
|
||||
track_id: str = None
|
||||
history: [Detection] = field(default_factory=lambda: [])
|
||||
predictor_history: Optional[list] = None # in image space
|
||||
predictions: Optional[list] = None
|
||||
|
||||
def get_projected_history(self, H) -> np.array:
|
||||
foot_coordinates = [d.get_foot_coords() for d in self.history]
|
||||
|
||||
if len(foot_coordinates):
|
||||
coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
|
||||
return coords[0]
|
||||
return np.array([])
|
||||
|
||||
def get_projected_history_as_dict(self, H) -> dict:
|
||||
coords = self.get_projected_history(H)
|
||||
return [{"x":c[0], "y":c[1]} for c in coords]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frame:
|
||||
index: int
|
||||
img: np.array
|
||||
time: float= field(default_factory=lambda: time.time())
|
||||
tracks: Optional[dict[str, Track]] = None
|
||||
H: Optional[np.array] = None
|
||||
trajectories: Optional[dict] = None
|
||||
|
||||
class FrameEmitter:
|
||||
'''
|
||||
|
@ -88,10 +36,10 @@ class FrameEmitter:
|
|||
|
||||
logger.info(f"Connection socket {config.zmq_frame_addr}")
|
||||
|
||||
if self.config.video_loop:
|
||||
self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
|
||||
if self.config.video_no_loop:
|
||||
self.video_srcs = self.config.video_src
|
||||
else:
|
||||
self.video_srcs: [Path] = self.config.video_src
|
||||
self.video_srcs = cycle(self.config.video_src)
|
||||
|
||||
|
||||
def emit_video(self):
|
||||
|
@ -103,7 +51,6 @@ class FrameEmitter:
|
|||
logger.info(f"Emit frames at {fps} fps")
|
||||
|
||||
prev_time = time.time()
|
||||
i = 0
|
||||
while self.is_running.is_set():
|
||||
ret, img = video.read()
|
||||
|
||||
|
@ -114,13 +61,7 @@ class FrameEmitter:
|
|||
# video.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
# ret, img = video.read()
|
||||
# assert ret is not False # not really error proof...
|
||||
|
||||
|
||||
if "DATASETS/hof/" in str(video_path):
|
||||
# hack to mask out area
|
||||
cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
|
||||
|
||||
frame = Frame(index=i, img=img)
|
||||
frame = Frame(img=img)
|
||||
# TODO: this is very dirty, need to find another way.
|
||||
# perhaps multiprocessing Array?
|
||||
self.frame_sock.send(pickle.dumps(frame))
|
||||
|
@ -133,14 +74,11 @@ class FrameEmitter:
|
|||
new_frame_time += frame_duration - time_diff
|
||||
else:
|
||||
prev_time = new_frame_time
|
||||
|
||||
i += 1
|
||||
|
||||
if not self.is_running.is_set():
|
||||
# if not running, also break out of infinite generator loop
|
||||
break
|
||||
|
||||
|
||||
logger.info("Stopping")
|
||||
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ def start():
|
|||
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
|
||||
]
|
||||
|
||||
if args.render_file or args.render_url:
|
||||
if args.render_preview:
|
||||
procs.append(
|
||||
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
|
||||
)
|
||||
|
|
|
@ -27,7 +27,6 @@ import matplotlib.pyplot as plt
|
|||
import zmq
|
||||
|
||||
from trap.frame_emitter import Frame
|
||||
from trap.tracker import Track
|
||||
|
||||
logger = logging.getLogger("trap.prediction")
|
||||
|
||||
|
@ -243,38 +242,33 @@ class PredictionServer:
|
|||
if self.config.predict_training_data:
|
||||
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
||||
else:
|
||||
zmq_ev = self.trajectory_socket.poll(timeout=3)
|
||||
if not zmq_ev:
|
||||
# on no data loop so that is_running is checked
|
||||
continue
|
||||
|
||||
data = self.trajectory_socket.recv()
|
||||
frame: Frame = pickle.loads(data)
|
||||
# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
|
||||
trajectory_data = frame.trajectories # TODO: properly refractor
|
||||
# trajectory_data = json.loads(data)
|
||||
logger.debug(f"Receive {frame.index}")
|
||||
logger.debug(f"Receive {trajectory_data}")
|
||||
|
||||
# class FakeNode:
|
||||
# def __init__(self, node_type: NodeType):
|
||||
# self.type = node_type
|
||||
|
||||
input_dict = {}
|
||||
for identifier, track in frame.tracks.items():
|
||||
for identifier, trajectory in trajectory_data.items():
|
||||
# if len(trajectory['history']) < 7:
|
||||
# # TODO: these trajectories should still be in the output, but without predictions
|
||||
# continue
|
||||
|
||||
# TODO: modify this into a mapping function between JS data an the expected Node format
|
||||
# node = FakeNode(online_env.NodeType.PEDESTRIAN)
|
||||
history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
|
||||
history = [[h['x'], h['y']] for h in trajectory['history']]
|
||||
history = np.array(history)
|
||||
x = history[:, 0]
|
||||
y = history[:, 1]
|
||||
# TODO: calculate dt based on input
|
||||
vx = derivative_of(x, 0.1) #eval_scene.dt
|
||||
vy = derivative_of(y, 0.1)
|
||||
ax = derivative_of(vx, 0.1)
|
||||
ay = derivative_of(vy, 0.1)
|
||||
vx = derivative_of(x, 0.2) #eval_scene.dt
|
||||
vy = derivative_of(y, 0.2)
|
||||
ax = derivative_of(vx, 0.2)
|
||||
ay = derivative_of(vy, 0.2)
|
||||
|
||||
data_dict = {('position', 'x'): x[:],
|
||||
('position', 'y'): y[:],
|
||||
|
@ -302,7 +296,7 @@ class PredictionServer:
|
|||
# And want to update the network
|
||||
|
||||
data = json.dumps({})
|
||||
self.prediction_socket.send_pyobj(frame)
|
||||
self.prediction_socket.send_string(data)
|
||||
|
||||
continue
|
||||
|
||||
|
@ -326,7 +320,7 @@ class PredictionServer:
|
|||
dists, preds = trajectron.incremental_forward(input_dict,
|
||||
maps,
|
||||
prediction_horizon=25, # TODO: make variable
|
||||
num_samples=5, # TODO: make variable
|
||||
num_samples=20, # TODO: make variable
|
||||
robot_present_and_future=robot_present_and_future,
|
||||
full_dist=True)
|
||||
end = time.time()
|
||||
|
@ -344,7 +338,6 @@ class PredictionServer:
|
|||
# prediction_dict provides the actual predictions
|
||||
# histories_dict provides the trajectory used for prediction
|
||||
# futures_dict is the Ground Truth, which is unvailable in an online setting
|
||||
|
||||
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
|
||||
eval_scene.dt,
|
||||
hyperparams['maximum_history_length'],
|
||||
|
@ -365,29 +358,24 @@ class PredictionServer:
|
|||
|
||||
for node in histories_dict:
|
||||
history = histories_dict[node]
|
||||
# future = futures_dict[node] # ground truth dict
|
||||
# future = futures_dict[node]
|
||||
predictions = prediction_dict[node]
|
||||
|
||||
if not len(history) or np.isnan(history[-1]).any():
|
||||
continue
|
||||
|
||||
# response[node.id] = {
|
||||
# 'id': node.id,
|
||||
# 'det_conf': trajectory_data[node.id]['det_conf'],
|
||||
# 'bbox': trajectory_data[node.id]['bbox'],
|
||||
# 'history': history.tolist(),
|
||||
# 'predictions': predictions[0].tolist() # use batch 0
|
||||
# }
|
||||
response[node.id] = {
|
||||
'id': node.id,
|
||||
'history': history.tolist(),
|
||||
'predictions': predictions[0].tolist() # use batch 0
|
||||
}
|
||||
|
||||
frame.tracks[node.id].predictor_history = history.tolist()
|
||||
frame.tracks[node.id].predictions = predictions[0].tolist() # use batch 0
|
||||
|
||||
# data = json.dumps(response)
|
||||
data = json.dumps(response)
|
||||
if self.config.predict_training_data:
|
||||
logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
|
||||
else:
|
||||
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
||||
self.prediction_socket.send_pyobj(frame)
|
||||
self.prediction_socket.send_string(data)
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
|
|
134
trap/renderer.py
134
trap/renderer.py
|
@ -1,4 +1,4 @@
|
|||
import ffmpeg
|
||||
|
||||
from argparse import Namespace
|
||||
import datetime
|
||||
import logging
|
||||
|
@ -33,65 +33,27 @@ class Renderer:
|
|||
|
||||
self.inv_H = np.linalg.pinv(self.H)
|
||||
|
||||
# TODO: get FPS from frame_emitter
|
||||
# self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
|
||||
self.fps = 10
|
||||
self.frame_size = (1280,720)
|
||||
self.out_writer = self.start_writer() if self.config.render_file else None
|
||||
self.streaming_process = self.start_streaming() if self.config.render_url else None
|
||||
|
||||
def start_writer(self):
|
||||
if not self.config.output_dir.exists():
|
||||
raise FileNotFoundError("Path does not exist")
|
||||
|
||||
date_str = datetime.datetime.now().isoformat(timespec="minutes")
|
||||
filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
|
||||
filename = self.config.output_dir / f"render_predictions-{date_str}.mp4"
|
||||
logger.info(f"Write to {filename}")
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'vp09')
|
||||
|
||||
return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
|
||||
|
||||
def start_streaming(self):
|
||||
return (
|
||||
ffmpeg
|
||||
.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
|
||||
.output(
|
||||
self.config.render_url,
|
||||
#codec = "copy", # use same codecs of the original video
|
||||
codec='libx264',
|
||||
listen=1, # enables HTTP server
|
||||
pix_fmt="yuv420p",
|
||||
preset="ultrafast",
|
||||
tune="zerolatency",
|
||||
g=f"{self.fps*2}",
|
||||
analyzeduration="2000000",
|
||||
probesize="1000000",
|
||||
f='mpegts'
|
||||
)
|
||||
.overwrite_output()
|
||||
.run_async(pipe_stdin=True)
|
||||
)
|
||||
# return process
|
||||
|
||||
|
||||
# TODO: get FPS from frame_emitter
|
||||
self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
|
||||
|
||||
|
||||
def run(self):
|
||||
prediction_frame = None
|
||||
predictions = {}
|
||||
i=0
|
||||
first_time = None
|
||||
while self.is_running.is_set():
|
||||
i+=1
|
||||
|
||||
zmq_ev = self.frame_sock.poll(timeout=3)
|
||||
if not zmq_ev:
|
||||
# when no data comes in, loop so that is_running is checked
|
||||
continue
|
||||
|
||||
frame: Frame = self.frame_sock.recv_pyobj()
|
||||
try:
|
||||
prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
|
||||
except zmq.ZMQError as e:
|
||||
logger.debug(f'reuse prediction')
|
||||
|
||||
|
@ -106,77 +68,47 @@ class Renderer:
|
|||
# warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
|
||||
# cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
|
||||
|
||||
if not prediction_frame:
|
||||
cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
continue
|
||||
else:
|
||||
for track_id, track in prediction_frame.tracks.items():
|
||||
if not len(track.history):
|
||||
continue
|
||||
|
||||
# coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||
coords = [d.get_foot_coords() for d in track.history]
|
||||
# logger.warning(f"{coords=}")
|
||||
|
||||
for ci in range(1, len(coords)):
|
||||
start = [int(p) for p in coords[ci-1]]
|
||||
end = [int(p) for p in coords[ci]]
|
||||
cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
|
||||
|
||||
for track_id, prediction in predictions.items():
|
||||
if not 'history' in prediction or not len(prediction['history']):
|
||||
continue
|
||||
|
||||
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||
# logger.warning(f"{coords=}")
|
||||
center = [int(p) for p in coords[-1]]
|
||||
cv2.circle(img, center, 5, (0,255,0))
|
||||
cv2.putText(img, track_id, (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(0,255,0))
|
||||
|
||||
if not track.predictions or not len(track.predictions):
|
||||
continue
|
||||
for ci in range(1, len(coords)):
|
||||
start = [int(p) for p in coords[ci-1]]
|
||||
end = [int(p) for p in coords[ci]]
|
||||
cv2.line(img, start, end, (255,255,255), 2)
|
||||
|
||||
if not 'predictions' in prediction or not len(prediction['predictions']):
|
||||
continue
|
||||
|
||||
for pred in prediction['predictions']:
|
||||
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||
for ci in range(1, len(pred_coords)):
|
||||
start = [int(p) for p in pred_coords[ci-1]]
|
||||
end = [int(p) for p in pred_coords[ci]]
|
||||
cv2.line(img, start, end, (0,0,255), 1)
|
||||
|
||||
for pred_i, pred in enumerate(track.predictions):
|
||||
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||
color = (0,0,255) if pred_i == 1 else (100,100,100)
|
||||
for ci in range(1, len(pred_coords)):
|
||||
start = [int(p) for p in pred_coords[ci-1]]
|
||||
end = [int(p) for p in pred_coords[ci]]
|
||||
cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
|
||||
|
||||
for track_id, track in prediction_frame.tracks.items():
|
||||
# draw tracker marker and track id last so it lies over the trajectories
|
||||
# this goes is a second loop so it overlays over _all_ trajectories
|
||||
# coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
|
||||
coords = track.history[-1].get_foot_coords()
|
||||
|
||||
center = [int(p) for p in coords]
|
||||
cv2.circle(img, center, 5, (0,255,0))
|
||||
(l, t, r, b) = track.history[-1].to_ltrb()
|
||||
p1 = (l, t)
|
||||
p2 = (r, b)
|
||||
cv2.rectangle(img, p1, p2, (255,0,0), 1)
|
||||
cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
|
||||
|
||||
if first_time is None:
|
||||
first_time = frame.time
|
||||
|
||||
cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
|
||||
if prediction_frame:
|
||||
cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
|
||||
|
||||
|
||||
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||
|
||||
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
|
||||
|
||||
# cv2.imwrite(str(img_path), img)
|
||||
logger.info(f"write frame {frame.time - first_time:.3f}s")
|
||||
if self.out_writer:
|
||||
self.out_writer.write(img)
|
||||
if self.streaming_process:
|
||||
self.streaming_process.stdin.write(img.tobytes())
|
||||
self.out.write(img)
|
||||
logger.info('Stopping')
|
||||
|
||||
if i>2:
|
||||
if self.streaming_process:
|
||||
self.streaming_process.stdin.close()
|
||||
if self.out_writer:
|
||||
self.out_writer.release()
|
||||
if self.streaming_process:
|
||||
# oddly wrapped, because both close and release() take time.
|
||||
self.streaming_process.wait()
|
||||
self.out.release()
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
|
||||
from argparse import Namespace
|
||||
import asyncio
|
||||
import errno
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
import subprocess
|
||||
from typing import Set, Union, Dict, Any
|
||||
from typing_extensions import Self
|
||||
|
||||
|
@ -147,14 +145,7 @@ class WsRouter:
|
|||
|
||||
# loop = tornado.ioloop.IOLoop.current()
|
||||
logger.info(f"Listen on {self.config.ws_port}")
|
||||
try:
|
||||
self.application.listen(self.config.ws_port)
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
logger.critical("Address already in use by process")
|
||||
subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
|
||||
raise
|
||||
|
||||
self.application.listen(self.config.ws_port)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
task = self.evt_loop.create_task(self.prediction_forwarder())
|
||||
|
@ -165,12 +156,9 @@ class WsRouter:
|
|||
async def prediction_forwarder(self):
|
||||
logger.info("Starting prediction forwarder")
|
||||
while self.is_running.is_set():
|
||||
# timeout so that if no events occur, loop can still stop on is_running
|
||||
has_event = await self.prediction_socket.poll(timeout=1)
|
||||
if has_event:
|
||||
msg = await self.prediction_socket.recv_string()
|
||||
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
||||
WebSocketPredictionHandler.write_to_clients(msg)
|
||||
msg = await self.prediction_socket.recv_string()
|
||||
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
||||
WebSocketPredictionHandler.write_to_clients(msg)
|
||||
|
||||
# die together:
|
||||
self.evt_loop.stop()
|
||||
|
|
154
trap/tracker.py
154
trap/tracker.py
|
@ -8,24 +8,22 @@ from multiprocessing import Event
|
|||
from pathlib import Path
|
||||
import pickle
|
||||
import time
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
import cv2
|
||||
|
||||
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
|
||||
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
|
||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||
from torchvision.models import ResNet50_Weights
|
||||
from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.engine.results import Results as YOLOResult
|
||||
|
||||
from trap.frame_emitter import Frame, Detection, Track
|
||||
from trap.frame_emitter import Frame
|
||||
|
||||
# Detection = [int, int, int, int, float, int]
|
||||
# Detections = [Detection]
|
||||
Detection = [int, int, int, int, float, int]
|
||||
Detections = [Detection]
|
||||
|
||||
# This is the dt that is also used by the scene.
|
||||
# as this needs to be rather stable, try to adhere
|
||||
|
@ -35,15 +33,32 @@ TARGET_DT = .1
|
|||
|
||||
logger = logging.getLogger("trap.tracker")
|
||||
|
||||
DETECTOR_RETINANET = 'retinanet'
|
||||
DETECTOR_MASKRCNN = 'maskrcnn'
|
||||
DETECTOR_FASTERRCNN = 'fasterrcnn'
|
||||
DETECTOR_RESNET = 'resnet'
|
||||
DETECTOR_YOLOv8 = 'ultralytics'
|
||||
|
||||
DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
|
||||
DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
|
||||
|
||||
@dataclass
|
||||
class Track:
|
||||
track_id: str = None
|
||||
history: [Detection]= field(default_factory=lambda: [])
|
||||
|
||||
@dataclass
|
||||
class Detection:
|
||||
track_id: str
|
||||
l: int # left
|
||||
t: int # top
|
||||
w: int # width
|
||||
h: int # height
|
||||
|
||||
def get_foot_coords(self):
|
||||
return [self.l + 0.5 * self.w, self.t+self.h]
|
||||
|
||||
@classmethod
|
||||
def from_deepsort(cls, dstrack: DeepsortTrack):
|
||||
return cls(dstrack.track_id, *dstrack.to_ltwh())
|
||||
|
||||
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, config: Namespace, is_running: Event):
|
||||
|
@ -66,40 +81,27 @@ class Tracker:
|
|||
# TODO: support removal
|
||||
self.tracks = defaultdict(lambda: Track())
|
||||
|
||||
if self.config.detector == DETECTOR_RETINANET:
|
||||
if self.config.detector == DETECTOR_RESNET:
|
||||
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
|
||||
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
|
||||
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
||||
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
|
||||
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
|
||||
self.model.to(self.device)
|
||||
# Put the model in inference mode
|
||||
self.model.eval()
|
||||
# Get the transforms for the model's weights
|
||||
self.preprocess = weights.transforms().to(self.device)
|
||||
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
|
||||
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
|
||||
)
|
||||
elif self.config.detector == DETECTOR_MASKRCNN:
|
||||
weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
|
||||
self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
|
||||
self.model.to(self.device)
|
||||
# Put the model in inference mode
|
||||
self.model.eval()
|
||||
# Get the transforms for the model's weights
|
||||
self.preprocess = weights.transforms().to(self.device)
|
||||
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
|
||||
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
|
||||
)
|
||||
elif self.config.detector == DETECTOR_YOLOv8:
|
||||
self.model = YOLO('EXPERIMENTS/yolov8x.pt')
|
||||
else:
|
||||
raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
|
||||
raise RuntimeError("No valid detector specified. See --help")
|
||||
|
||||
|
||||
# homography = list(source.glob('*img2world.txt'))[0]
|
||||
|
||||
self.H = np.loadtxt(self.config.homography, delimiter=',')
|
||||
|
||||
self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
|
||||
logger.debug("Set up tracker")
|
||||
|
||||
|
||||
|
@ -122,29 +124,17 @@ class Tracker:
|
|||
# following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
|
||||
training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
|
||||
|
||||
prev_frame_i = -1
|
||||
|
||||
frame_i = 0
|
||||
while self.is_running.is_set():
|
||||
# this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
|
||||
# skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
|
||||
# so for now, timing should move to emitter
|
||||
# this_run_time = time.time()
|
||||
# # logger.debug(f'test {prev_run_time - this_run_time}')
|
||||
# time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
|
||||
# prev_run_time = time.time()
|
||||
|
||||
start_time = time.time()
|
||||
frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
|
||||
|
||||
if frame.index > (prev_frame_i+1):
|
||||
logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
|
||||
|
||||
|
||||
prev_frame_i = frame.index
|
||||
# load homography into frame (TODO: should this be done in emitter?)
|
||||
frame.H = self.H
|
||||
this_run_time = time.time()
|
||||
# logger.debug(f'test {prev_run_time - this_run_time}')
|
||||
time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
|
||||
prev_run_time = time.time()
|
||||
|
||||
msg = self.frame_sock.recv()
|
||||
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
|
||||
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
if self.config.detector == DETECTOR_YOLOv8:
|
||||
|
@ -154,44 +144,38 @@ class Tracker:
|
|||
|
||||
|
||||
# Store detections into tracklets
|
||||
projected_coordinates = []
|
||||
for detection in detections:
|
||||
track = self.tracks[detection.track_id]
|
||||
track.track_id = detection.track_id # for new tracks
|
||||
|
||||
track.history.append(detection) # add to history
|
||||
# projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
|
||||
|
||||
# TODO: hadle occlusions, and dissappearance
|
||||
track.history.append(detection)
|
||||
# if len(track.history) > 30: # retain 90 tracks for 90 frames
|
||||
# track.history.pop(0)
|
||||
|
||||
|
||||
# trajectories = {}
|
||||
# for detection in detections:
|
||||
# tid = str(detection.track_id)
|
||||
# track = self.tracks[detection.track_id]
|
||||
# coords = track.get_projected_history(self.H) # get full history
|
||||
# trajectories[tid] = {
|
||||
# "id": tid,
|
||||
# "det_conf": detection.conf,
|
||||
# "bbox": detection.to_ltwh(),
|
||||
# "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
|
||||
# }
|
||||
active_track_ids = [d.track_id for d in detections]
|
||||
active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
|
||||
# logger.info(f"{trajectories}")
|
||||
frame.tracks = active_tracks
|
||||
foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
|
||||
if len(foot_coordinates):
|
||||
projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
|
||||
else:
|
||||
projected_coordinates = []
|
||||
|
||||
# print(TEMP_proj_coords)
|
||||
trajectories = {}
|
||||
for detection, coords in zip(detections, projected_coordinates):
|
||||
tid = str(detection.track_id)
|
||||
trajectories[tid] = {
|
||||
"id": tid,
|
||||
"history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test
|
||||
}
|
||||
# logger.info(f"{trajectories}")
|
||||
frame.trajectories = trajectories
|
||||
|
||||
# if self.config.bypass_prediction:
|
||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
# else:
|
||||
# self.trajectory_socket.send(pickle.dumps(frame))
|
||||
self.trajectory_socket.send_pyobj(frame)
|
||||
|
||||
current_time = time.time()
|
||||
logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||
|
||||
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||
if self.config.bypass_prediction:
|
||||
self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
else:
|
||||
self.trajectory_socket.send(pickle.dumps(frame))
|
||||
|
||||
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
||||
# TODO: provide a track object that actually keeps history (unlike tracker)
|
||||
|
@ -200,14 +184,13 @@ class Tracker:
|
|||
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
|
||||
if training_csv:
|
||||
training_csv.writerows([{
|
||||
'frame_id': round(frame.index * 10., 1), # not really time
|
||||
'frame_id': round(frame_i * 10., 1), # not really time
|
||||
'track_id': t['id'],
|
||||
'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
|
||||
'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
|
||||
} for t in active_tracks.values()])
|
||||
training_frames += len(active_tracks)
|
||||
# print(time.time() - start_time)
|
||||
|
||||
'x': t['history'][-1]['x'],
|
||||
'y': t['history'][-1]['y'],
|
||||
} for t in trajectories.values()])
|
||||
training_frames += len(trajectories)
|
||||
frame_i += 1
|
||||
|
||||
if training_fp:
|
||||
training_fp.close()
|
||||
|
@ -231,9 +214,6 @@ class Tracker:
|
|||
|
||||
def _yolov8_track(self, img) -> [Detection]:
|
||||
results: [YOLOResult] = self.model.track(img, persist=True)
|
||||
if results[0].boxes is None or results[0].boxes.id is None:
|
||||
# work around https://github.com/ultralytics/ultralytics/issues/5968
|
||||
return []
|
||||
return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
|
||||
|
||||
def _resnet_track(self, img) -> [Detection]:
|
||||
|
@ -241,7 +221,7 @@ class Tracker:
|
|||
tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
|
||||
return [Detection.from_deepsort(t) for t in tracks]
|
||||
|
||||
def _resnet_detect_persons(self, frame) -> [Detection]:
|
||||
def _resnet_detect_persons(self, frame) -> Detections:
|
||||
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
# change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
|
||||
t = t.permute(2, 0, 1)
|
||||
|
|
Loading…
Reference in a new issue