Compare commits
No commits in common. "cd8a06a53bc0da129a2fba26a1b56b93d7fa85fb" and "185962ace5c777e39205ae8ff603107f3d381e75" have entirely different histories.
cd8a06a53b...185962ace5
10 changed files with 1254 additions and 1393 deletions
poetry.lock (generated, 1813 lines changed)
File diff suppressed because it is too large.
@@ -26,9 +26,6 @@ torchvision = [
 ]
 deep-sort-realtime = "^1.3.2"
 ultralytics = "^8.0.200"
-ffmpeg-python = "^0.2.0"
-torchreid = "^0.2.5"
-gdown = "^4.7.1"
 
 [build-system]
 requires = ["poetry-core"]

File diff suppressed because one or more lines are too long.
@@ -194,8 +194,8 @@ frame_emitter_parser.add_argument("--video-src",
 default=lambda: list(Path('../DATASETS/VIRAT_subset_0102x/').glob('*.mp4')))
 #TODO: camera as source
 
-frame_emitter_parser.add_argument("--video-loop",
-help="By default it emitter will run only once. This allows it to loop the video file to keep testing.",
+frame_emitter_parser.add_argument("--video-no-loop",
+help="By default it emitter will run indefiniately. This prevents that and plays every video only once.",
 action='store_true')
 #TODO: camera as source
 

@@ -217,17 +217,7 @@ tracker_parser.add_argument("--detector",
 
 # Renderer
 
-render_parser.add_argument("--render-file",
+render_parser.add_argument("--render-preview",
 help="Render a video file previewing the prediction, and its delay compared to the current frame",
 action='store_true')
 
-render_parser.add_argument("--render-url",
-help="""Stream renderer on given URL. Two easy approaches:
-- using zmq wrapper one can specify the LISTENING ip. To listen to any incoming connection: zmq:tcp://0.0.0.0:5556
-- alternatively, using e.g. UDP one needs to specify the IP of the client. E.g. udp://100.69.123.91:5556/stream
-Note that with ZMQ you can have multiple clients connecting simultaneously. E.g. using `ffplay zmq:tcp://100.109.175.82:5556`
-When using udp, connecting can be done using `ffplay udp://100.109.175.82:5556/stream`
-""",
-type=str,
-default=None)
-
@@ -3,73 +3,21 @@ from dataclasses import dataclass, field
 from itertools import cycle
 import logging
 from multiprocessing import Event
-from pathlib import Path
 import pickle
 import sys
 import time
-from typing import Iterable, Optional
+from typing import Optional
 import numpy as np
 import cv2
 import zmq
-from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
-
 logger = logging.getLogger('trap.frame_emitter')
 
 
-@dataclass
-class Detection:
-track_id: str # deepsort track id association
-l: int # left - image space
-t: int # top - image space
-w: int # width - image space
-h: int # height - image space
-conf: float # object detector probablity
-
-def get_foot_coords(self):
-return [self.l + 0.5 * self.w, self.t+self.h]
-
-@classmethod
-def from_deepsort(cls, dstrack: DeepsortTrack):
-return cls(dstrack.track_id, *dstrack.to_ltwh(), dstrack.det_conf)
-
-def to_ltwh(self):
-return (int(self.l), int(self.t), int(self.w), int(self.h))
-
-def to_ltrb(self):
-return (int(self.l), int(self.t), int(self.l+self.w), int(self.t+self.h))
-
-
-@dataclass
-class Track:
-"""A bit of an haphazardous wrapper around the 'real' tracker to provide
-a history, with which the predictor can work, as we then can deduce velocity
-and acceleration.
-"""
-track_id: str = None
-history: [Detection] = field(default_factory=lambda: [])
-predictor_history: Optional[list] = None # in image space
-predictions: Optional[list] = None
-
-def get_projected_history(self, H) -> np.array:
-foot_coordinates = [d.get_foot_coords() for d in self.history]
-
-if len(foot_coordinates):
-coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
-return coords[0]
-return np.array([])
-
-def get_projected_history_as_dict(self, H) -> dict:
-coords = self.get_projected_history(H)
-return [{"x":c[0], "y":c[1]} for c in coords]
-
-
 @dataclass
 class Frame:
-index: int
 img: np.array
 time: float= field(default_factory=lambda: time.time())
-tracks: Optional[dict[str, Track]] = None
-H: Optional[np.array] = None
+trajectories: Optional[dict] = None
 
 class FrameEmitter:
 '''

@@ -88,10 +36,10 @@ class FrameEmitter:
 
 logger.info(f"Connection socket {config.zmq_frame_addr}")
 
-if self.config.video_loop:
-self.video_srcs: Iterable[Path] = cycle(self.config.video_src)
+if self.config.video_no_loop:
+self.video_srcs = self.config.video_src
 else:
-self.video_srcs: [Path] = self.config.video_src
+self.video_srcs = cycle(self.config.video_src)
 
 
 def emit_video(self):

@@ -103,7 +51,6 @@ class FrameEmitter:
 logger.info(f"Emit frames at {fps} fps")
 
 prev_time = time.time()
-i = 0
 while self.is_running.is_set():
 ret, img = video.read()
 

@@ -114,13 +61,7 @@ class FrameEmitter:
 # video.set(cv2.CAP_PROP_POS_FRAMES, 0)
 # ret, img = video.read()
 # assert ret is not False # not really error proof...
-
-if "DATASETS/hof/" in str(video_path):
-# hack to mask out area
-cv2.rectangle(img, (0,0), (800,200), (0,0,0), -1)
-
-frame = Frame(index=i, img=img)
+frame = Frame(img=img)
 # TODO: this is very dirty, need to find another way.
 # perhaps multiprocessing Array?
 self.frame_sock.send(pickle.dumps(frame))

@@ -134,13 +75,10 @@ class FrameEmitter:
 else:
 prev_time = new_frame_time
 
-i += 1
-
 if not self.is_running.is_set():
 # if not running, also break out of infinite generator loop
 break
 
 
 logger.info("Stopping")
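The Detection and Track dataclasses removed from the frame emitter above are where image-space boxes turn into ground-plane trajectories: the bottom-centre of each bounding box is taken as a "foot" point and mapped through the homography H with cv2.perspectiveTransform. A minimal sketch of that projection, assuming a made-up homography in place of the matrix normally loaded via --homography:

# Sketch of the projection done by the removed Track.get_projected_history /
# get_projected_history_as_dict. H below is an arbitrary example matrix.
import cv2
import numpy as np

H = np.array([[0.05, 0.0, -10.0],
              [0.0, 0.05, -5.0],
              [0.0, 0.0, 1.0]])

# "foot" points of consecutive detections, in pixels: [l + 0.5*w, t + h]
foot_coordinates = [[640.0, 700.0], [650.0, 695.0], [661.0, 688.0]]

# perspectiveTransform expects a (1, N, 2) float array; [0] drops that extra axis again
coords = cv2.perspectiveTransform(np.array([foot_coordinates]), H)[0]
history_as_dict = [{"x": float(c[0]), "y": float(c[1])} for c in coords]
print(history_as_dict)

The resulting list of {"x", "y"} dicts is the shape of data the prediction server reads back out of trajectory['history'] in the hunks further down.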
@@ -73,7 +73,7 @@ def start():
 ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
 ]
 
-if args.render_file or args.render_url:
+if args.render_preview:
 procs.append(
 ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
 )
@@ -27,7 +27,6 @@ import matplotlib.pyplot as plt
 import zmq
 
 from trap.frame_emitter import Frame
-from trap.tracker import Track
 
 logger = logging.getLogger("trap.prediction")
 

@@ -243,38 +242,33 @@ class PredictionServer:
 if self.config.predict_training_data:
 input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
 else:
-zmq_ev = self.trajectory_socket.poll(timeout=3)
-if not zmq_ev:
-# on no data loop so that is_running is checked
-continue
-
 data = self.trajectory_socket.recv()
 frame: Frame = pickle.loads(data)
-# trajectory_data = {t.track_id: t.get_projected_history_as_dict(frame.H) for t in frame.tracks.values()}
+trajectory_data = frame.trajectories # TODO: properly refractor
 # trajectory_data = json.loads(data)
-logger.debug(f"Receive {frame.index}")
+logger.debug(f"Receive {trajectory_data}")
 
 # class FakeNode:
 # def __init__(self, node_type: NodeType):
 # self.type = node_type
 
 input_dict = {}
-for identifier, track in frame.tracks.items():
+for identifier, trajectory in trajectory_data.items():
 # if len(trajectory['history']) < 7:
 # # TODO: these trajectories should still be in the output, but without predictions
 # continue
 
 # TODO: modify this into a mapping function between JS data an the expected Node format
 # node = FakeNode(online_env.NodeType.PEDESTRIAN)
-history = [[h['x'], h['y']] for h in track.get_projected_history_as_dict(frame.H)]
+history = [[h['x'], h['y']] for h in trajectory['history']]
 history = np.array(history)
 x = history[:, 0]
 y = history[:, 1]
 # TODO: calculate dt based on input
-vx = derivative_of(x, 0.1) #eval_scene.dt
-vy = derivative_of(y, 0.1)
-ax = derivative_of(vx, 0.1)
-ay = derivative_of(vy, 0.1)
+vx = derivative_of(x, 0.2) #eval_scene.dt
+vy = derivative_of(y, 0.2)
+ax = derivative_of(vx, 0.2)
+ay = derivative_of(vy, 0.2)
 
 data_dict = {('position', 'x'): x[:],
 ('position', 'y'): y[:],

@@ -302,7 +296,7 @@ class PredictionServer:
 # And want to update the network
 
 data = json.dumps({})
-self.prediction_socket.send_pyobj(frame)
+self.prediction_socket.send_string(data)
 
 continue
 

@@ -326,7 +320,7 @@ class PredictionServer:
 dists, preds = trajectron.incremental_forward(input_dict,
 maps,
 prediction_horizon=25, # TODO: make variable
-num_samples=5, # TODO: make variable
+num_samples=20, # TODO: make variable
 robot_present_and_future=robot_present_and_future,
 full_dist=True)
 end = time.time()

@@ -344,7 +338,6 @@ class PredictionServer:
 # prediction_dict provides the actual predictions
 # histories_dict provides the trajectory used for prediction
 # futures_dict is the Ground Truth, which is unvailable in an online setting
-
 prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
 eval_scene.dt,
 hyperparams['maximum_history_length'],

@@ -365,29 +358,24 @@ class PredictionServer:
 
 for node in histories_dict:
 history = histories_dict[node]
-# future = futures_dict[node] # ground truth dict
+# future = futures_dict[node]
 predictions = prediction_dict[node]
 
 if not len(history) or np.isnan(history[-1]).any():
 continue
 
-# response[node.id] = {
-# 'id': node.id,
-# 'det_conf': trajectory_data[node.id]['det_conf'],
-# 'bbox': trajectory_data[node.id]['bbox'],
-# 'history': history.tolist(),
-# 'predictions': predictions[0].tolist() # use batch 0
-# }
-
-frame.tracks[node.id].predictor_history = history.tolist()
-frame.tracks[node.id].predictions = predictions[0].tolist() # use batch 0
-
-# data = json.dumps(response)
+response[node.id] = {
+'id': node.id,
+'history': history.tolist(),
+'predictions': predictions[0].tolist() # use batch 0
+}
+
+data = json.dumps(response)
 if self.config.predict_training_data:
 logger.info(f"Frame prediction: {len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s")
 else:
 logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
-self.prediction_socket.send_pyobj(frame)
+self.prediction_socket.send_string(data)
 
 logger.info('Stopping')
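The only numeric change in the derivative block above is the dt handed to derivative_of (0.1 becomes 0.2), i.e. the assumed time between trajectory samples when positions are differentiated into velocities and accelerations for the Trajectron input dict. A rough, self-contained sketch of that step; np.gradient stands in for Trajectron++'s derivative_of here (an assumption about its internals) and the history values are invented:

import numpy as np

dt = 0.2  # seconds between samples; the hunk raises this from 0.1 to 0.2
history = np.array([[0.0, 0.0], [0.1, 0.05], [0.25, 0.12], [0.45, 0.20]])  # world-space x, y

x, y = history[:, 0], history[:, 1]
vx, vy = np.gradient(x, dt), np.gradient(y, dt)    # velocity by finite differences
ax, ay = np.gradient(vx, dt), np.gradient(vy, dt)  # acceleration likewise

data_dict = {('position', 'x'): x, ('position', 'y'): y,
             ('velocity', 'x'): vx, ('velocity', 'y'): vy,
             ('acceleration', 'x'): ax, ('acceleration', 'y'): ay}
print(data_dict[('velocity', 'x')])

Doubling dt halves every derived velocity and acceleration for the same displacement, so the change directly affects what the model sees as input.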
trap/renderer.py (108 lines changed)
@@ -1,4 +1,4 @@
-import ffmpeg
 from argparse import Namespace
 import datetime
 import logging

@@ -33,65 +33,27 @@ class Renderer:
 
 self.inv_H = np.linalg.pinv(self.H)
 
-# TODO: get FPS from frame_emitter
-# self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
-self.fps = 10
-self.frame_size = (1280,720)
-self.out_writer = self.start_writer() if self.config.render_file else None
-self.streaming_process = self.start_streaming() if self.config.render_url else None
-
-def start_writer(self):
 if not self.config.output_dir.exists():
 raise FileNotFoundError("Path does not exist")
 
 date_str = datetime.datetime.now().isoformat(timespec="minutes")
-filename = self.config.output_dir / f"render_predictions-{date_str}-{self.config.detector}.mp4"
+filename = self.config.output_dir / f"render_predictions-{date_str}.mp4"
 logger.info(f"Write to {filename}")
 
 fourcc = cv2.VideoWriter_fourcc(*'vp09')
-return cv2.VideoWriter(str(filename), fourcc, self.fps, self.frame_size)
+# TODO: get FPS from frame_emitter
+self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
 
-def start_streaming(self):
-return (
-ffmpeg
-.input('pipe:', format='rawvideo',codec="rawvideo", pix_fmt='bgr24', s='{}x{}'.format(*self.frame_size))
-.output(
-self.config.render_url,
-#codec = "copy", # use same codecs of the original video
-codec='libx264',
-listen=1, # enables HTTP server
-pix_fmt="yuv420p",
-preset="ultrafast",
-tune="zerolatency",
-g=f"{self.fps*2}",
-analyzeduration="2000000",
-probesize="1000000",
-f='mpegts'
-)
-.overwrite_output()
-.run_async(pipe_stdin=True)
-)
-# return process
-
-
 def run(self):
-prediction_frame = None
+predictions = {}
 i=0
 first_time = None
 while self.is_running.is_set():
 i+=1
 
-zmq_ev = self.frame_sock.poll(timeout=3)
-if not zmq_ev:
-# when no data comes in, loop so that is_running is checked
-continue
-
 frame: Frame = self.frame_sock.recv_pyobj()
 try:
-prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
+predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
 except zmq.ZMQError as e:
 logger.debug(f'reuse prediction')
 

@@ -106,77 +68,47 @@ class Renderer:
 # warpedFrame = cv2.warpPerspective(img, new_H, (1000,1000))
 # cv2.imwrite(str(self.config.output_dir / "orig.png"), warpedFrame)
 
-if not prediction_frame:
-cv2.putText(img, f"Waiting for prediction...", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
-continue
-else:
-for track_id, track in prediction_frame.tracks.items():
-if not len(track.history):
+for track_id, prediction in predictions.items():
+if not 'history' in prediction or not len(prediction['history']):
 continue
 
-# coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
-coords = [d.get_foot_coords() for d in track.history]
+coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
 # logger.warning(f"{coords=}")
+center = [int(p) for p in coords[-1]]
+cv2.circle(img, center, 5, (0,255,0))
+cv2.putText(img, track_id, (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.8, color=(0,255,0))
 
 for ci in range(1, len(coords)):
 start = [int(p) for p in coords[ci-1]]
 end = [int(p) for p in coords[ci]]
-cv2.line(img, start, end, (255,255,255), 2, lineType=cv2.LINE_AA)
+cv2.line(img, start, end, (255,255,255), 2)
 
-if not track.predictions or not len(track.predictions):
+if not 'predictions' in prediction or not len(prediction['predictions']):
 continue
 
-for pred_i, pred in enumerate(track.predictions):
+for pred in prediction['predictions']:
 pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
-color = (0,0,255) if pred_i == 1 else (100,100,100)
 for ci in range(1, len(pred_coords)):
 start = [int(p) for p in pred_coords[ci-1]]
 end = [int(p) for p in pred_coords[ci]]
-cv2.line(img, start, end, color, 1, lineType=cv2.LINE_AA)
+cv2.line(img, start, end, (0,0,255), 1)
 
-for track_id, track in prediction_frame.tracks.items():
-# draw tracker marker and track id last so it lies over the trajectories
-# this goes is a second loop so it overlays over _all_ trajectories
-# coords = cv2.perspectiveTransform(np.array([[track.history[-1].get_foot_coords()]]), self.inv_H)[0]
-coords = track.history[-1].get_foot_coords()
-
-center = [int(p) for p in coords]
-cv2.circle(img, center, 5, (0,255,0))
-(l, t, r, b) = track.history[-1].to_ltrb()
-p1 = (l, t)
-p2 = (r, b)
-cv2.rectangle(img, p1, p2, (255,0,0), 1)
-cv2.putText(img, f"{track_id} ({(track.history[-1].conf or 0):.2f})", (center[0]+8, center[1]), cv2.FONT_HERSHEY_SIMPLEX, fontScale=.7, thickness=2, color=(0,255,0), lineType=cv2.LINE_AA)
 
 if first_time is None:
 first_time = frame.time
 
-cv2.putText(img, f"{frame.index:06d}", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
-cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
+cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
 
-if prediction_frame:
-cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,50), cv2.FONT_HERSHEY_PLAIN, 1, (0,0,255), 1)
-
 img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
 
 # cv2.imwrite(str(img_path), img)
-logger.info(f"write frame {frame.time - first_time:.3f}s")
-if self.out_writer:
-self.out_writer.write(img)
-if self.streaming_process:
-self.streaming_process.stdin.write(img.tobytes())
+self.out.write(img)
 logger.info('Stopping')
 
 if i>2:
-if self.streaming_process:
-self.streaming_process.stdin.close()
-if self.out_writer:
-self.out_writer.release()
-if self.streaming_process:
-# oddly wrapped, because both close and release() take time.
-self.streaming_process.wait()
+self.out.release()
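For context on the start_streaming method this compare removes from trap/renderer.py: it pipes raw BGR frames into an ffmpeg process (through the ffmpeg-python package that also disappears from the dependency list above), which re-encodes them and serves an MPEG-TS stream on --render-url. A condensed sketch of that pipeline, reusing the fps and frame size from the removed code and an example URL taken from the removed --render-url help text:

import ffmpeg   # ffmpeg-python, the dependency dropped from the project file
import numpy as np

frame_size = (1280, 720)
fps = 10
render_url = "zmq:tcp://0.0.0.0:5556"  # clients connect with e.g. `ffplay zmq:tcp://<host>:5556`

process = (
    ffmpeg
    .input('pipe:', format='rawvideo', pix_fmt='bgr24', s='{}x{}'.format(*frame_size))
    .output(render_url, codec='libx264', listen=1, pix_fmt='yuv420p',
            preset='ultrafast', tune='zerolatency', g=fps * 2, f='mpegts')
    .overwrite_output()
    .run_async(pipe_stdin=True)
)

# the render loop wrote each finished frame to ffmpeg's stdin; black frames stand in here
for _ in range(fps):
    img = np.zeros((frame_size[1], frame_size[0], 3), dtype=np.uint8)
    process.stdin.write(img.tobytes())

process.stdin.close()  # both closing stdin and waiting take a moment, hence the
process.wait()         # separate shutdown steps in the removed run() method

Without this path, the "+" side of the compare only writes the .mp4 file via cv2.VideoWriter; the live preview over zmq/udp is gone.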
@@ -1,10 +1,8 @@
 
 from argparse import Namespace
 import asyncio
-import errno
 import logging
 from multiprocessing import Event
-import subprocess
 from typing import Set, Union, Dict, Any
 from typing_extensions import Self
 

@@ -147,14 +145,7 @@ class WsRouter:
 
 # loop = tornado.ioloop.IOLoop.current()
 logger.info(f"Listen on {self.config.ws_port}")
-try:
 self.application.listen(self.config.ws_port)
-except OSError as e:
-if e.errno == errno.EADDRINUSE:
-logger.critical("Address already in use by process")
-subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
-raise
-
 loop = asyncio.get_event_loop()
 
 task = self.evt_loop.create_task(self.prediction_forwarder())

@@ -165,9 +156,6 @@ class WsRouter:
 async def prediction_forwarder(self):
 logger.info("Starting prediction forwarder")
 while self.is_running.is_set():
-# timeout so that if no events occur, loop can still stop on is_running
-has_event = await self.prediction_socket.poll(timeout=1)
-if has_event:
 msg = await self.prediction_socket.recv_string()
 logger.debug(f"Forward prediction message of {len(msg)} chars")
 WebSocketPredictionHandler.write_to_clients(msg)
trap/tracker.py (148 lines changed)
@@ -8,24 +8,22 @@ from multiprocessing import Event
 from pathlib import Path
 import pickle
 import time
-from typing import Optional
 import numpy as np
 import torch
 import zmq
 import cv2
 
-from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights, maskrcnn_resnet50_fpn_v2, MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights, keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
 from deep_sort_realtime.deepsort_tracker import DeepSort
-from torchvision.models import ResNet50_Weights
 from deep_sort_realtime.deep_sort.track import Track as DeepsortTrack
 
 from ultralytics import YOLO
 from ultralytics.engine.results import Results as YOLOResult
 
-from trap.frame_emitter import Frame, Detection, Track
+from trap.frame_emitter import Frame
 
-# Detection = [int, int, int, int, float, int]
-# Detections = [Detection]
+Detection = [int, int, int, int, float, int]
+Detections = [Detection]
 
 # This is the dt that is also used by the scene.
 # as this needs to be rather stable, try to adhere

@@ -35,13 +33,30 @@ TARGET_DT = .1
 
 logger = logging.getLogger("trap.tracker")
 
-DETECTOR_RETINANET = 'retinanet'
-DETECTOR_MASKRCNN = 'maskrcnn'
-DETECTOR_FASTERRCNN = 'fasterrcnn'
+DETECTOR_RESNET = 'resnet'
 DETECTOR_YOLOv8 = 'ultralytics'
 
-DETECTORS = [DETECTOR_RETINANET, DETECTOR_MASKRCNN, DETECTOR_FASTERRCNN, DETECTOR_YOLOv8]
+DETECTORS = [DETECTOR_RESNET, DETECTOR_YOLOv8]
 
+@dataclass
+class Track:
+track_id: str = None
+history: [Detection]= field(default_factory=lambda: [])
+
+@dataclass
+class Detection:
+track_id: str
+l: int # left
+t: int # top
+w: int # width
+h: int # height
+
+def get_foot_coords(self):
+return [self.l + 0.5 * self.w, self.t+self.h]
+
+@classmethod
+def from_deepsort(cls, dstrack: DeepsortTrack):
+return cls(dstrack.track_id, *dstrack.to_ltwh())
+

@@ -66,40 +81,27 @@ class Tracker:
 # TODO: support removal
 self.tracks = defaultdict(lambda: Track())
 
-if self.config.detector == DETECTOR_RETINANET:
+if self.config.detector == DETECTOR_RESNET:
 # weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
 # self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.2)
 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
-self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.35)
+self.model = keypointrcnn_resnet50_fpn(weights=weights, box_score_thresh=0.20)
 self.model.to(self.device)
 # Put the model in inference mode
 self.model.eval()
 # Get the transforms for the model's weights
 self.preprocess = weights.transforms().to(self.device)
-self.mot_tracker = DeepSort(max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
-# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
-)
-elif self.config.detector == DETECTOR_MASKRCNN:
-weights = MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1
-self.model = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.7)
-self.model.to(self.device)
-# Put the model in inference mode
-self.model.eval()
-# Get the transforms for the model's weights
-self.preprocess = weights.transforms().to(self.device)
-self.mot_tracker = DeepSort(n_init=5, max_iou_distance=1, max_cosine_distance=0.5, max_age=12, nms_max_overlap=0.9,
-# embedder='torchreid', embedder_wts="../MODELS/osnet_x1_0_imagenet.pth"
-)
 elif self.config.detector == DETECTOR_YOLOv8:
 self.model = YOLO('EXPERIMENTS/yolov8x.pt')
 else:
-raise RuntimeError(f"{self.config.detector} is not implemented yet. See --help")
+raise RuntimeError("No valid detector specified. See --help")
 
 
 # homography = list(source.glob('*img2world.txt'))[0]
 
 self.H = np.loadtxt(self.config.homography, delimiter=',')
 
+self.mot_tracker = DeepSort(max_age=30, nms_max_overlap=0.9)
 logger.debug("Set up tracker")

@@ -122,29 +124,17 @@ class Tracker:
 # following https://github.com/StanfordASL/Trajectron-plus-plus/blob/master/experiments/pedestrians/process_data.py
 training_csv = csv.DictWriter(training_fp, fieldnames=['frame_id', 'track_id', 'x', 'y'], delimiter='\t', quoting=csv.QUOTE_NONE)
 
-prev_frame_i = -1
+frame_i = 0
 
 while self.is_running.is_set():
-# this waiting for target_dt causes frame loss. E.g. with target_dt at .1, it
-# skips exactly 1 frame on a 10 fps video (which, it obviously should not do)
-# so for now, timing should move to emitter
-# this_run_time = time.time()
-# # logger.debug(f'test {prev_run_time - this_run_time}')
-# time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
-# prev_run_time = time.time()
-
-start_time = time.time()
-frame: Frame = self.frame_sock.recv_pyobj() # frame delivery in current setup: 0.012-0.03s
-
-if frame.index > (prev_frame_i+1):
-logger.warn(f"Dropped {frame.index - prev_frame_i - 1} frames ({frame.index=}, {prev_frame_i=})")
-
-prev_frame_i = frame.index
-# load homography into frame (TODO: should this be done in emitter?)
-frame.H = self.H
-
+this_run_time = time.time()
+# logger.debug(f'test {prev_run_time - this_run_time}')
+time.sleep(max(0, prev_run_time - this_run_time + TARGET_DT))
+prev_run_time = time.time()
+
+msg = self.frame_sock.recv()
+frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
 # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
+start_time = time.time()
 
 
 if self.config.detector == DETECTOR_YOLOv8:

@@ -154,43 +144,37 @@ class Tracker:
 
 
 # Store detections into tracklets
-projected_coordinates = []
 for detection in detections:
 track = self.tracks[detection.track_id]
 track.track_id = detection.track_id # for new tracks
 
-track.history.append(detection) # add to history
-# projected_coordinates.append(track.get_projected_history(self.H)) # then get full history
-
-# TODO: hadle occlusions, and dissappearance
+track.history.append(detection)
 # if len(track.history) > 30: # retain 90 tracks for 90 frames
 # track.history.pop(0)
 
+foot_coordinates = np.array([[t.get_foot_coords()] for t in detections])
+if len(foot_coordinates):
+projected_coordinates = cv2.perspectiveTransform(foot_coordinates,self.H)
+else:
+projected_coordinates = []
 
-# trajectories = {}
-# for detection in detections:
-# tid = str(detection.track_id)
-# track = self.tracks[detection.track_id]
-# coords = track.get_projected_history(self.H) # get full history
-# trajectories[tid] = {
-# "id": tid,
-# "det_conf": detection.conf,
-# "bbox": detection.to_ltwh(),
-# "history": [{"x":c[0], "y":c[1]} for c in coords[0]] if not self.config.bypass_prediction else coords[0].tolist() # already doubles nested, fine for test
-# }
-active_track_ids = [d.track_id for d in detections]
-active_tracks = {t.track_id: t for t in self.tracks.values() if t.track_id in active_track_ids}
+# print(TEMP_proj_coords)
+trajectories = {}
+for detection, coords in zip(detections, projected_coordinates):
+tid = str(detection.track_id)
+trajectories[tid] = {
+"id": tid,
+"history": [{"x":c[0], "y":c[1]} for c in coords] if not self.config.bypass_prediction else coords.tolist() # already doubles nested, fine for test
+}
 # logger.info(f"{trajectories}")
-frame.tracks = active_tracks
+frame.trajectories = trajectories
 
-# if self.config.bypass_prediction:
-# self.trajectory_socket.send_string(json.dumps(trajectories))
-# else:
-# self.trajectory_socket.send(pickle.dumps(frame))
-self.trajectory_socket.send_pyobj(frame)
 
 current_time = time.time()
-logger.debug(f"Trajectories: {len(active_tracks)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
+logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
+if self.config.bypass_prediction:
+self.trajectory_socket.send_string(json.dumps(trajectories))
+else:
+self.trajectory_socket.send(pickle.dumps(frame))
+
 # self.trajectory_socket.send_string(json.dumps(trajectories))
 # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}

@@ -200,14 +184,13 @@ class Tracker:
 # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
 if training_csv:
 training_csv.writerows([{
-'frame_id': round(frame.index * 10., 1), # not really time
+'frame_id': round(frame_i * 10., 1), # not really time
 'track_id': t['id'],
-'x': t['history'][-1]['x' if not self.config.bypass_prediction else 0],
-'y': t['history'][-1]['y' if not self.config.bypass_prediction else 1],
-} for t in active_tracks.values()])
-training_frames += len(active_tracks)
-# print(time.time() - start_time)
+'x': t['history'][-1]['x'],
+'y': t['history'][-1]['y'],
+} for t in trajectories.values()])
+training_frames += len(trajectories)
+frame_i += 1
 
 
 if training_fp:
 training_fp.close()

@@ -231,9 +214,6 @@ class Tracker:
 
 def _yolov8_track(self, img) -> [Detection]:
 results: [YOLOResult] = self.model.track(img, persist=True)
-if results[0].boxes is None or results[0].boxes.id is None:
-# work around https://github.com/ultralytics/ultralytics/issues/5968
-return []
 return [Detection(track_id, *bbox) for bbox, track_id in zip(results[0].boxes.xywh.cpu(), results[0].boxes.id.int().cpu().tolist())]
 
 def _resnet_track(self, img) -> [Detection]:

@@ -241,7 +221,7 @@ class Tracker:
 tracks: [DeepsortTrack] = self.mot_tracker.update_tracks(detections, frame=img)
 return [Detection.from_deepsort(t) for t in tracks]
 
-def _resnet_detect_persons(self, frame) -> [Detection]:
+def _resnet_detect_persons(self, frame) -> Detections:
 t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
 # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C)
 t = t.permute(2, 0, 1)
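A recurring change across these hunks is how frames cross process boundaries: one side calls socket.send_pyobj(frame) and recv_pyobj(), the other socket.send(pickle.dumps(frame)) and pickle.loads(socket.recv()). In pyzmq those forms are interchangeable, which the standalone sketch below illustrates (PAIR sockets over inproc are used purely for demonstration; the actual socket types and addresses in trap are not visible in this diff):

import pickle
import zmq

ctx = zmq.Context.instance()
a, b = ctx.socket(zmq.PAIR), ctx.socket(zmq.PAIR)
a.bind("inproc://frames")
b.connect("inproc://frames")

# stand-in for a pickled Frame carrying a trajectories dict, as sent by the tracker
payload = {"trajectories": {"1": {"id": "1", "history": [{"x": 0.0, "y": 0.0}]}}}

a.send(pickle.dumps(payload))         # explicit pickling, as in the tracker's send path
assert b.recv_pyobj() == payload      # recv_pyobj() simply unpickles the received bytes

a.send_pyobj(payload)                 # the convenience form used elsewhere in the diff
assert pickle.loads(b.recv()) == payload

a.close(); b.close()

The practical difference between the two commits is therefore not the wire format but the payload: a Frame carrying Track objects on one side, versus a plain trajectories dict (or a JSON string when bypass_prediction is set) on the other.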