Compare commits
10 commits
185962ace5
...
cd8a06a53b
Author | SHA1 | Date
---|---|---
 | cd8a06a53b |
 | 44a618a5ee |
 | a3e42b4501 |
 | f3b8e031c1 |
 | ec9bb357fd |
 | 104098d371 |
 | 1aed4161f6 |
 | fd2e8a3b49 |
 | 3091557733 |
 | b911caa5af |
10 changed files with 1392 additions and 1253 deletions
1813 poetry.lock (generated)
File diff suppressed because it is too large
@@ -26,6 +26,9 @@ torchvision = [
 ]
 deep-sort-realtime = "^1.3.2"
 ultralytics = "^8.0.200"
+ffmpeg-python = "^0.2.0"
+torchreid = "^0.2.5"
+gdown = "^4.7.1"
 
 [build-system]
 requires = ["poetry-core"]
383 test_homography.ipynb (new file)
File diff suppressed because one or more lines are too long
@@ -194,8 +194,8 @@ frame_emitter_parser.add_argument("--video-src",
                     default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
 #TODO: camera as source
 
-frame_emitter_parser.add_argument("--video-no-loop",
-                    help="By default it emitter will run indefiniately. This prevents that and plays every video only once.",
+frame_emitter_parser.add_argument("--video-loop",
+                    help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
                     action='store_true')
 #TODO: camera as source
 
@@ -217,7 +217,17 @@ tracker_parser.add_argument("--detector",
 
 # Renderer
 
-render_parser.add_argument("--render-preview",
+render_parser.add_argument("--render-file",
                     help="Render a video file previewing the prediction, and its delay compared to the current frame",
                     action='store_true')
 
+render_parser.add_argument("--render-url",
+                    help="""Stream renderer on given URL. Two easy approaches:
+                    - using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
+                    - alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
+                    Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
+                    When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
+                    """,
+                    type=str,
+                    default=None)
 
@@ -3,21 +3,73 @@ from dataclasses import dataclass, field
 from itertools import cycle
 import logging
 from multiprocessing import Event
+from pathlib import Path
 import pickle
 import sys
 import time
-from typing import Optional
+from typing import Iterable, Optional
 import numpy as np
 import cv2
 import zmq
+from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 
 logger = logging.getLogger('trap.frame_emitter')
 
+
+@dataclass
+class Detection:
+    track_id: str # deepsort track id association
+    l: int # left - image space
+    t: int # top - image space
+    w: int # width - image space
+    h: int # height - image space
+    conf: float # object detector probablity
+
+    def get_foot_coords(self):
+        return [self.l + 0.5 * self.w, self.t+self.h]
+
+    @classmethod
+    def from_deepsort(cls, dstrack: DeepsortTrack):
+        return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
+
+    def to_ltwh(self):
+        return (int(self.l), int(self.t), int(self.w), int(self.h))
+
+    def to_ltrb(self):
+        return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
+
+
+@dataclass
+class Track:
+    """A bit of an haphazardous wrapper around the 'real' tracker to provide
+    a history, with which the predictor can work, as we then can deduce velocity
+    and acceleration.
+    """
+    track_id: str = None
+    history: [Detection] = field(default_factory=lambda: [])
+    predictor_history: Optional[list] = None # in image space
+    predictions: Optional[list] = None
+
+    def get_projected_history(self, H) -> np.array:
+        foot_coordinates = [d.get_foot_coords() for d in self.history]
+
+        if len(foot_coordinates):
+            coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
+            return coords[0]
+        return np.array([])
+
+    def get_projected_history_as_dict(self, H) -> dict:
+        coords = self.get_projected_history(H)
+        return [{"x":c[0], "y":c[1]} for c in coords]
+
+
 @dataclass
 class Frame:
+    index: int
     img: np.array
     time: float= field(default_factory=lambda: time.time())
-    trajectories: Optional[dict] = None
+    tracks: Optional[dict[str, Track]] = None
+    H: Optional[np.array] = None
+
 
 class FrameEmitter:
     '''
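The new Detection and Track dataclasses keep detections in image space and only project them on demand: get_projected_history collects the bottom-centre ("foot") point of every bounding box in the history and maps it through the homography H with cv2.perspectiveTransform. A minimal sketch of that projection step, with a made-up homography and foot points (all values are illustrative, not taken from the repository):

```python
import cv2
import numpy as np

# Stand-in for the homography normally loaded via np.loadtxt(config.homography)
H = np.array([[0.05, 0.0, -10.0],
              [0.0, 0.05, -5.0],
              [0.0, 0.0, 1.0]])

# Bottom-centre points of two hypothetical bounding boxes, in pixels
foot_coordinates = [[640.0, 700.0], [655.0, 690.0]]

# Same call Track.get_projected_history makes: note the extra batch dimension
coords = cv2.perspectiveTransform(np.array([foot_coordinates]), H)
print(coords[0])                                      # ground-plane (x, y) per detection
print([{"x": c[0], "y": c[1]} for c in coords[0]])    # the dict shape the predictor consumes
```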
@@ -36,10 +88,10 @@ class FrameEmitter:
 
         logger.info(f"Connection socket {config.zmq_frame_addr}")
 
-        if self.config.video_no_loop:
-            self.video_srcs = self.config.video_src
+        if self.config.video_loop:
+            self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
         else:
-            self.video_srcs = cycle(self.config.video_src)
+            self.video_srcs: [Path] = self.config.video_src
 
 
     def emit_video(self):
@@ -51,6 +103,7 @@ class FrameEmitter:
         logger.info(f"Emit frames at {fps} fps")
 
         prev_time = time.time()
+        i = 0
         while self.is_running.is_set():
             ret, img = video.read()
 
@@ -61,7 +114,13 @@ class FrameEmitter:
                 # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
                 # ret, img = video.read()
                 # assert ret is not False # not really error proof...
-            frame = Frame(img=img)
+
+            if "DATASETS/hof/" in str(video_path):
+                # hack to mask out area
+                cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
+
+            frame = Frame(index=i, img=img)
             # TODO: this is very dirty, need to find another way.
             # perhaps multiprocessing Array?
             self.frame_sock.send(pickle.dumps(frame))
@@ -75,10 +134,13 @@ class FrameEmitter:
             else:
                 prev_time = new_frame_time
 
+            i += 1
+
             if not self.is_running.is_set():
                 # if not running, also break out of infinite generator loop
                 break
 
 
         logger.info("Stopping")
 
@@ -73,7 +73,7 @@ def start():
         ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
     ]
 
-    if args.render_preview:
+    if args.render_file or args.render_url:
         procs.append(
             ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
         )
 
@@ -27,6 +27,7 @@ import matplotlib.pyplot as plt
 import zmq
 
 from trap.frame_emitter import Frame
+from trap.tracker import Track
 
 logger = logging.getLogger("trap.prediction")
 
@@ -242,33 +243,38 @@ class PredictionServer:
             if self.config.predict_training_data:
                 input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
             else:
+                zmq_ev = self.trajectory_socket.poll(timeout=3)
+                if not zmq_ev:
+                    # on no data loop so that is_running is checked
+                    continue
+
                 data = self.trajectory_socket.recv()
                 frame: Frame = pickle.loads(data)
-                trajectory_data = frame.trajectories # TODO: properly refractor
+                # trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
                 # trajectory_data = json.loads(data)
-                logger.debug(f"Receive {trajectory_data}")
+                logger.debug(f"Receive {frame.index}")
 
                 # class FakeNode:
                 #     def __init__(self, node_type: NodeType):
                 #         self.type = node_type
 
                 input_dict = {}
-                for identifier, trajectory in trajectory_data.items():
+                for identifier, track in frame.tracks.items():
                     # if len(trajectory['history']) < 7:
                     #     # TODO: these trajectories should still be in the output, but without predictions
                     #     continue
 
                     # TODO: modify this into a mapping function between JS data an the expected Node format
                     # node = FakeNode(online_env.NodeType.PEDESTRIAN)
-                    history = [[h['x'], h['y']] for h in trajectory['history']]
+                    history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
                     history = np.array(history)
                     x = history[:, 0]
                     y = history[:, 1]
                     # TODO: calculate dt based on input
-                    vx = derivative_of(x, 0.2) #eval_scene.dt
-                    vy = derivative_of(y, 0.2)
-                    ax = derivative_of(vx, 0.2)
-                    ay = derivative_of(vy, 0.2)
+                    vx = derivative_of(x, 0.1) #eval_scene.dt
+                    vy = derivative_of(y, 0.1)
+                    ax = derivative_of(vx, 0.1)
+                    ay = derivative_of(vy, 0.1)
 
                     data_dict = {('position', 'x'): x[:],
                              ('position', 'y'): y[:],
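The same hunk drops the timestep passed to derivative_of from 0.2 to 0.1, which matches the tracker's TARGET_DT of .1 and the 10 fps video sources. derivative_of itself comes from Trajectron++; the stand-in below only illustrates why that dt argument scales the derived velocities and accelerations, under the assumption that the helper behaves like a plain finite difference:

```python
import numpy as np

def finite_difference(values: np.ndarray, dt: float) -> np.ndarray:
    """Rough stand-in for Trajectron++'s derivative_of helper (assumption)."""
    return np.gradient(values, dt)

x = np.array([0.0, 0.1, 0.25, 0.45, 0.7])   # made-up projected x positions, one per frame
vx = finite_difference(x, 0.1)              # ~[1.0, 1.25, 1.75, 2.25, 2.5] units per second
ax = finite_difference(vx, 0.1)             # acceleration from the velocity series
print(vx, ax)                               # with dt=0.2 every value would be halved
```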
@@ -296,7 +302,7 @@ class PredictionServer:
                 # And want to update the network
 
                 data = json.dumps({})
-                self.prediction_socket.send_string(data)
+                self.prediction_socket.send_pyobj(frame)
 
                 continue
 
@@ -320,7 +326,7 @@ class PredictionServer:
             dists, preds = trajectron.incremental_forward(input_dict,
                                                           maps,
                                                           prediction_horizon=25, # TODO: make variable
-                                                          num_samples=20, # TODO: make variable
+                                                          num_samples=5, # TODO: make variable
                                                           robot_present_and_future=robot_present_and_future,
                                                           full_dist=True)
             end = time.time()
@@ -338,6 +344,7 @@ class PredictionServer:
             # prediction_dict provides the actual predictions
             # histories_dict provides the trajectory used for prediction
             # futures_dict is the Ground Truth, which is unvailable in an online setting
+
             prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
                                                                                               eval_scene.dt,
                                                                                               hyperparams['maximum_history_length'],
@@ -358,24 +365,29 @@ class PredictionServer:
 
             for node in histories_dict:
                 history = histories_dict[node]
-                # future = futures_dict[node]
+                # future = futures_dict[node] # ground truth dict
                 predictions = prediction_dict[node]
 
                 if not len(history) or np.isnan(history[-1]).any():
                     continue
 
-                response[node.id] = {
-                    'id': node.id,
-                    'history': history.tolist(),
-                    'predictions': predictions[0].tolist() # use batch 0
-                }
+                # response[node.id] = {
+                #     'id': node.id,
+                #     'det_conf': trajectory_data[node.id]['det_conf'],
+                #     'bbox': trajectory_data[node.id]['bbox'],
+                #     'history': history.tolist(),
+                #     'predictions': predictions[0].tolist() # use batch 0
+                # }
 
-            data = json.dumps(response)
+                frame.tracks[node.id].predictor_history = history.tolist()
+                frame.tracks[node.id].predictions = predictions[0].tolist() # use batch 0
+
+            # data = json.dumps(response)
             if self.config.predict_training_data:
                 logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
             else:
                 logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
-            self.prediction_socket.send_string(data)
+            self.prediction_socket.send_pyobj(frame)
         logger.info('Stopping')
 
128 trap/renderer.py
@@ -1,4 +1,4 @@
-
+import ffmpeg
 from argparse import Namespace
 import datetime
 import logging
@@ -33,27 +33,65 @@ class Renderer:
 
         self.inv_H = np.linalg.pinv(self.H)
 
+        # TODO: get FPS from frame_emitter
+        # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
+        self.fps = 10
+        self.frame_size = (1280,720)
+        self.out_writer = self.start_writer() if self.config.render_file else None
+        self.streaming_process = self.start_streaming() if self.config.render_url else None
+
+    def start_writer(self):
         if not self.config.output_dir.exists():
             raise FileNotFoundError("Path does not exist")
 
         date_str = datetime.datetime.now().isoformat(timespec="minutes")
-        filename = self.config.output_dir / f"render_predictions-{date_str}.mp4"
+        filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
         logger.info(f"Write to {filename}")
 
         fourcc = cv2.VideoWriter_fourcc(*'vp09')
-        # TODO: get FPS from frame_emitter
-        self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
+        return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
+
+    def start_streaming(self):
+        return (
+            ffmpeg
+            .input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
+            .output(
+                self.config.render_url,
+                #codec = "copy", # use same codecs of the original video
+                codec='libx264',
+                listen=1, # enables HTTP server
+                pix_fmt="yuv420p",
+                preset="ultrafast",
+                tune="zerolatency",
+                g=f"{self.fps*2}",
+                analyzeduration="2000000",
+                probesize="1000000",
+                f='mpegts'
+            )
+            .overwrite_output()
+            .run_async(pipe_stdin=True)
+        )
+        # return process
+
 
     def run(self):
-        predictions = {}
+        prediction_frame = None
         i=0
         first_time = None
         while self.is_running.is_set():
             i+=1
 
+            zmq_ev = self.frame_sock.poll(timeout=3)
+            if not zmq_ev:
+                # when no data comes in, loop so that is_running is checked
+                continue
+
             frame: Frame = self.frame_sock.recv_pyobj()
             try:
-                predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
+                prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
             except zmq.ZMQError as e:
                 logger.debug(f'reuse prediction')
 
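start_streaming builds an ffmpeg-python pipeline that reads raw bgr24 frames from stdin and re-encodes them into an MPEG-TS stream on render_url; run() (next hunk) feeds it one frame at a time with streaming_process.stdin.write(img.tobytes()). A stripped-down sketch of the same pipe, with an illustrative URL, frame size and a single black frame:

```python
import ffmpeg
import numpy as np

frame_size = (1280, 720)

# Raw frames in via stdin, H.264-in-MPEG-TS out on a zmq URL (placeholder address)
process = (
    ffmpeg
    .input('pipe:', format='rawvideo', pix_fmt='bgr24', s='{}x{}'.format(*frame_size))
    .output('zmq:tcp://0.0.0.0:5556', codec='libx264', pix_fmt='yuv420p',
            preset='ultrafast', tune='zerolatency', f='mpegts')
    .overwrite_output()
    .run_async(pipe_stdin=True)
)

img = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)  # one black bgr24 frame
process.stdin.write(img.tobytes())   # what Renderer.run does per rendered frame
process.stdin.close()
process.wait()
```

A viewer can then attach with e.g. `ffplay zmq:tcp://<host>:5556`, as the --render-url help text above describes.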
@@ -68,47 +106,77 @@ class Renderer:
             # warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
             # cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
 
-            for track_id, prediction in predictions.items():
-                if not 'history' in prediction or not len(prediction['history']):
-                    continue
-
-                coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
-                # logger.warning(f"{coords=}")
-                center = [int(p) for p in coords[-1]]
-                cv2.circle(img, center, 5, (0,255,0))
-                cv2.putText(img, track_id, (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(0,255,0))
-
-                for ci in range(1, len(coords)):
-                    start = [int(p) for p in coords[ci-1]]
-                    end = [int(p) for p in coords[ci]]
-                    cv2.line(img, start, end, (255,255,255), 2)
-
-                if not 'predictions' in prediction or not len(prediction['predictions']):
-                    continue
-
-                for pred in prediction['predictions']:
-                    pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
-                    for ci in range(1, len(pred_coords)):
-                        start = [int(p) for p in pred_coords[ci-1]]
-                        end = [int(p) for p in pred_coords[ci]]
-                        cv2.line(img, start, end, (0,0,255), 1)
+            if not prediction_frame:
+                cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
+                continue
+            else:
+                for track_id, track in prediction_frame.tracks.items():
+                    if not len(track.history):
+                        continue
+
+                    # coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
+                    coords = [d.get_foot_coords() for d in track.history]
+                    # logger.warning(f"{coords=}")
+
+                    for ci in range(1, len(coords)):
+                        start = [int(p) for p in coords[ci-1]]
+                        end = [int(p) for p in coords[ci]]
+                        cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
+
+                    if not track.predictions or not len(track.predictions):
+                        continue
+
+                    for pred_i, pred in enumerate(track.predictions):
+                        pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
+                        color = (0,0,255) if pred_i == 1 else (100,100,100)
+                        for ci in range(1, len(pred_coords)):
+                            start = [int(p) for p in pred_coords[ci-1]]
+                            end = [int(p) for p in pred_coords[ci]]
+                            cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
+
+                for track_id, track in prediction_frame.tracks.items():
+                    # draw tracker marker and track id last so it lies over the trajectories
+                    # this goes is a second loop so it overlays over _all_ trajectories
+                    # coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
+                    coords = track.history[-1].get_foot_coords()
+
+                    center = [int(p) for p in coords]
+                    cv2.circle(img, center, 5, (0,255,0))
+                    (l, t, r, b) = track.history[-1].to_ltrb()
+                    p1 = (l, t)
+                    p2 = (r, b)
+                    cv2.rectangle(img, p1, p2, (255,0,0), 1)
+                    cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
 
             if first_time is None:
                 first_time = frame.time
 
-            cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
+            cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
+            cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
+
+            if prediction_frame:
+                cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
+
 
             img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
 
             # cv2.imwrite(str(img_path), img)
-            self.out.write(img)
+            logger.info(f"write frame {frame.time - first_time:.3f}s")
+            if self.out_writer:
+                self.out_writer.write(img)
+            if self.streaming_process:
+                self.streaming_process.stdin.write(img.tobytes())
         logger.info('Stopping')
 
         if i>2:
-            self.out.release()
+            if self.streaming_process:
+                self.streaming_process.stdin.close()
+            if self.out_writer:
+                self.out_writer.release()
+            if self.streaming_process:
+                # oddly wrapped, because both close and release() take time.
+                self.streaming_process.wait()
 
@@ -1,8 +1,10 @@
 
 from argparse import Namespace
 import asyncio
+import errno
 import logging
 from multiprocessing import Event
+import subprocess
 from typing import Set, Union, Dict, Any
 from typing_extensions import Self
 
@@ -145,7 +147,14 @@ class WsRouter:
 
         # loop = tornado.ioloop.IOLoop.current()
         logger.info(f"Listen on {self.config.ws_port}")
-        self.application.listen(self.config.ws_port)
+        try:
+            self.application.listen(self.config.ws_port)
+        except OSError as e:
+            if e.errno == errno.EADDRINUSE:
+                logger.critical("Address already in use by process")
+                subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
+            raise
+
         loop = asyncio.get_event_loop()
 
         task = self.evt_loop.create_task(self.prediction_forwarder())
@@ -156,9 +165,12 @@ class WsRouter:
     async def prediction_forwarder(self):
         logger.info("Starting prediction forwarder")
         while self.is_running.is_set():
-            msg = await self.prediction_socket.recv_string()
-            logger.debug(f"Forward prediction message of {len(msg)} chars")
-            WebSocketPredictionHandler.write_to_clients(msg)
+            # timeout so that if no events occur, loop can still stop on is_running
+            has_event = await self.prediction_socket.poll(timeout=1)
+            if has_event:
+                msg = await self.prediction_socket.recv_string()
+                logger.debug(f"Forward prediction message of {len(msg)} chars")
+                WebSocketPredictionHandler.write_to_clients(msg)
 
         # die together:
         self.evt_loop.stop()
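The poll-before-receive pattern added here also appears in the prediction server and the renderer: polling with a short timeout lets each loop re-check the shared is_running Event instead of blocking forever inside recv. A synchronous sketch of the pattern (the socket is whatever the caller provides; the timeout is in milliseconds):

```python
import zmq

def forward_until_stopped(sock: zmq.Socket, is_running) -> None:
    """Drain a socket until the shared multiprocessing Event is cleared."""
    while is_running.is_set():
        if not sock.poll(timeout=3):   # returns 0 when nothing arrived in time
            continue                   # no data: loop so is_running is re-checked
        msg = sock.recv()
        print(f"forwarding {len(msg)} bytes")
```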
150 trap/tracker.py
@@ -8,22 +8,24 @@ from multiprocessing import Event
 from pathlib import Path
 import pickle
 import time
+from typing import Optional
 import numpy as np
 import torch
 import zmq
 import cv2
 
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
 from deep_sort_realtime.deepsort_tracker import DeepSort
+from torchvision.models import ResNet50_Weights
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 
 from ultralytics import YOLO
 from ultralytics.engine.results import Results as YOLOResult
 
-from trap.frame_emitter import Frame
+from trap.frame_emitter import Frame, Detection, Track
 
-Detection = [int, int, int, int, float, int]
-Detections = [Detection]
+# Detection = [int, int, int, int, float, int]
+# Detections = [Detection]
 
 # This is the dt that is also used by the scene.
 # as this needs to be rather stable, try to adhere
@@ -33,30 +35,13 @@ TARGET_DT = .1
 
 logger = logging.getLogger("trap.tracker")
 
-DETECTOR_RESNET = 'resnet'
+DETECTOR_RETINANET = 'retinanet'
+DETECTOR_MASKRCNN = 'maskrcnn'
+DETECTOR_FASTERRCNN = 'fasterrcnn'
 DETECTOR_YOLOv8 = 'ultralytics'
 
-DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
+DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
 
-@dataclass
-class Track:
-    track_id: str = None
-    history: [Detection]= field(default_factory=lambda: [])
-
-@dataclass
-class Detection:
-    track_id: str
-    l: int # left
-    t: int # top
-    w: int # width
-    h: int # height
-
-    def get_foot_coords(self):
-        return [self.l + 0.5 * self.w, self.t+self.h]
-
-    @classmethod
-    def from_deepsort(cls, dstrack: DeepsortTrack):
-        return cls(dstrack.track_id, *dstrack.to_ltwh())
-
 
|
@ -81,27 +66,40 @@ class Tracker:
|
||||||
# TODO: support removal
|
# TODO: support removal
|
||||||
self.tracks = defaultdict(lambda: Track())
|
self.tracks = defaultdict(lambda: Track())
|
||||||
|
|
||||||
if self.config.detector == DETECTOR_RESNET:
|
if self.config.detector == DETECTOR_RETINANET:
|
||||||
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
|
# weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
|
||||||
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
|
# self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
|
||||||
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
|
||||||
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
|
self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
# Put the model in inference mode
|
# Put the model in inference mode
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
# Get the transforms for the model's weights
|
# Get the transforms for the model's weights
|
||||||
self.preprocess = weights.transforms().to(self.device)
|
self.preprocess = weights.transforms().to(self.device)
|
||||||
|
self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
|
||||||
|
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
|
||||||
|
)
|
||||||
|
elif self.config.detector == DETECTOR_MASKRCNN:
|
||||||
|
weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
|
||||||
|
self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
|
||||||
|
self.model.to(self.device)
|
||||||
|
# Put the model in inference mode
|
||||||
|
self.model.eval()
|
||||||
|
# Get the transforms for the model's weights
|
||||||
|
self.preprocess = weights.transforms().to(self.device)
|
||||||
|
self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
|
||||||
|
# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
|
||||||
|
)
|
||||||
elif self.config.detector == DETECTOR_YOLOv8:
|
elif self.config.detector == DETECTOR_YOLOv8:
|
||||||
self.model = YOLO('EXPERIMENTS/yolov8x.pt')
|
self.model = YOLO('EXPERIMENTS/yolov8x.pt')
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("No valid detector specified. See --help")
|
raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
|
||||||
|
|
||||||
|
|
||||||
# homography = list(source.glob('*img2world.txt'))[0]
|
# homography = list(source.glob('*img2world.txt'))[0]
|
||||||
|
|
||||||
self.H = np.loadtxt(self.config.homography, delimiter=',')
|
self.H = np.loadtxt(self.config.homography, delimiter=',')
|
||||||
|
|
||||||
self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
|
|
||||||
logger.debug("Set up tracker")
|
logger.debug("Set up tracker")
|
||||||
|
|
||||||
|
|
||||||
|
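Both torchvision branches now build their own DeepSort instance, and _resnet_track (further down in this file) hands detector output to mot_tracker.update_tracks(detections, frame=img), reading track_id, to_ltwh() and det_conf back off the returned tracks. A hedged sketch of that hand-off, assuming deep-sort-realtime's documented ([left, top, width, height], confidence, class) raw-detection format; the boxes and scores are made up:

```python
import numpy as np
from deep_sort_realtime.deepsort_tracker import DeepSort

mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12,
                       nms_max_overlap=0.9)

img = np.zeros((720, 1280, 3), dtype=np.uint8)   # stand-in camera frame

# Assumed raw-detection format: ([left, top, width, height], confidence, class)
raw_detections = [
    ([300, 200, 60, 160], 0.87, 0),
    ([520, 210, 55, 150], 0.74, 0),
]

tracks = mot_tracker.update_tracks(raw_detections, frame=img)
for t in tracks:
    # the same fields Detection.from_deepsort reads
    print(t.track_id, t.to_ltwh(), t.det_conf)
```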
@@ -124,17 +122,29 @@ class Tracker:
             # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
             training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
 
-        frame_i = 0
-        while self.is_running.is_set():
-            this_run_time = time.time()
-            # logger.debug(f'test {prev_run_time - this_run_time}')
-            time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
-            prev_run_time = time.time()
+        prev_frame_i = -1
+
+        while self.is_running.is_set():
+            # this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
+            # skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
+            # so for now, timing should move to emitter
+            # this_run_time = time.time()
+            # # logger.debug(f'test {prev_run_time - this_run_time}')
+            # time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
+            # prev_run_time = time.time()
 
-            msg = self.frame_sock.recv()
-            frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
-            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
             start_time = time.time()
+            frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
+
+            if frame.index > (prev_frame_i+1):
+                logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
+
+            prev_frame_i = frame.index
+            # load homography into frame (TODO: should this be done in emitter?)
+            frame.H = self.H
+
+            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
+
 
             if self.config.detector == DETECTOR_YOLOv8:
@@ -144,37 +154,43 @@ class Tracker:
 
 
             # Store detections into tracklets
+            projected_coordinates = []
             for detection in detections:
                 track = self.tracks[detection.track_id]
                 track.track_id = detection.track_id # for new tracks
 
-                track.history.append(detection)
+                track.history.append(detection) # add to history
+                # projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
 
+                # TODO: hadle occlusions, and dissappearance
                 # if len(track.history) > 30: # retain 90 tracks for 90 frames
                 #   track.history.pop(0)
 
-            foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
-            if len(foot_coordinates):
-                projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
-            else:
-                projected_coordinates = []
-
-            # print(TEMP_proj_coords)
-            trajectories = {}
-            for detection, coords in zip(detections, projected_coordinates):
-                tid = str(detection.track_id)
-                trajectories[tid] = {
-                    "id": tid,
-                    "history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test
-                }
+            # trajectories = {}
+            # for detection in detections:
+            #     tid = str(detection.track_id)
+            #     track = self.tracks[detection.track_id]
+            #     coords = track.get_projected_history(self.H) # get full history
+            #     trajectories[tid] = {
+            #         "id": tid,
+            #         "det_conf": detection.conf,
+            #         "bbox": detection.to_ltwh(),
+            #         "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
+            #     }
+            active_track_ids = [d.track_id for d in detections]
+            active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
             # logger.info(f"{trajectories}")
-            frame.trajectories = trajectories
+            frame.tracks = active_tracks
 
+            # if self.config.bypass_prediction:
+            #     self.trajectory_socket.send_string(json.dumps(trajectories))
+            # else:
+            #     self.trajectory_socket.send(pickle.dumps(frame))
+            self.trajectory_socket.send_pyobj(frame)
 
             current_time = time.time()
-            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
-            if self.config.bypass_prediction:
-                self.trajectory_socket.send_string(json.dumps(trajectories))
-            else:
-                self.trajectory_socket.send(pickle.dumps(frame))
+            logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
 
             # self.trajectory_socket.send_string(json.dumps(trajectories))
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
@@ -184,13 +200,14 @@ class Tracker:
             # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
             if training_csv:
                 training_csv.writerows([{
-                    'frame_id': round(frame_i * 10., 1), # not really time
+                    'frame_id': round(frame.index * 10., 1), # not really time
                     'track_id': t['id'],
-                    'x': t['history'][-1]['x'],
-                    'y': t['history'][-1]['y'],
-                } for t in trajectories.values()])
-                training_frames += len(trajectories)
-            frame_i += 1
+                    'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
+                    'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
+                } for t in active_tracks.values()])
+                training_frames += len(active_tracks)
+            # print(time.time() - start_time)
 
 
         if training_fp:
             training_fp.close()
@@ -214,6 +231,9 @@ class Tracker:
 
     def _yolov8_track(self, img) -> [Detection]:
        results: [YOLOResult] = self.model.track(img, persist=True)
+       if results[0].boxes is None or results[0].boxes.id is None:
+           # work around https://github.com/ultralytics/ultralytics/issues/5968
+           return []
        return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
 
     def _resnet_track(self, img) -> [Detection]:
@@ -221,7 +241,7 @@ class Tracker:
         tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
         return [Detection.from_deepsort(t) for t in tracks]
 
-    def _resnet_detect_persons(self, frame) -> Detections:
+    def _resnet_detect_persons(self, frame) -> [Detection]:
         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
         t = t.permute(2, 0, 1)