Split tracker in renderer
This commit is contained in:
parent
6e98138cc2
commit
c3263e7448
1 changed files with 53 additions and 20 deletions
|
@ -58,6 +58,8 @@ def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:
|
||||||
|
|
||||||
class DrawnTrack:
|
class DrawnTrack:
|
||||||
def __init__(self, track_id, track: Track, renderer: Renderer, H):
|
def __init__(self, track_id, track: Track, renderer: Renderer, H):
|
||||||
|
# self.created_at = time.time()
|
||||||
|
self.update_at = self.created_at = time.time()
|
||||||
self.track_id = track_id
|
self.track_id = track_id
|
||||||
self.renderer = renderer
|
self.renderer = renderer
|
||||||
self.set_track(track, H)
|
self.set_track(track, H)
|
||||||
|
@ -67,6 +69,8 @@ class DrawnTrack:
|
||||||
self.pred_shapes: list[list[pyglet.shapes.Line]] = []
|
self.pred_shapes: list[list[pyglet.shapes.Line]] = []
|
||||||
|
|
||||||
def set_track(self, track: Track, H):
|
def set_track(self, track: Track, H):
|
||||||
|
self.update_at = time.time()
|
||||||
|
|
||||||
self.track = track
|
self.track = track
|
||||||
self.H = H
|
self.H = H
|
||||||
self.coords = [d.get_foot_coords() for d in track.history]
|
self.coords = [d.get_foot_coords() for d in track.history]
|
||||||
|
@ -239,6 +243,11 @@ class Renderer:
|
||||||
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
|
self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
|
||||||
|
|
||||||
|
self.tracker_sock = context.socket(zmq.SUB)
|
||||||
|
self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.tracker_sock.connect(config.zmq_trajectory_addr)
|
||||||
|
|
||||||
self.frame_sock = context.socket(zmq.SUB)
|
self.frame_sock = context.socket(zmq.SUB)
|
||||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
@ -282,6 +291,7 @@ class Renderer:
|
||||||
|
|
||||||
self.first_time: float|None = None
|
self.first_time: float|None = None
|
||||||
self.frame: Frame|None= None
|
self.frame: Frame|None= None
|
||||||
|
self.tracker_frame: Frame|None = None
|
||||||
self.prediction_frame: Frame|None = None
|
self.prediction_frame: Frame|None = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -341,8 +351,9 @@ class Renderer:
|
||||||
|
|
||||||
def init_labels(self):
|
def init_labels(self):
|
||||||
base_color = (255,)*4
|
base_color = (255,)*4
|
||||||
info_color = (255,255,0, 255)
|
color_predictor = (255,255,0, 255)
|
||||||
info2_color = (255,0, 255, 255)
|
color_info = (255,0, 255, 255)
|
||||||
|
color_tracker = (0,255, 255, 255)
|
||||||
|
|
||||||
options = []
|
options = []
|
||||||
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||||
|
@ -351,11 +362,13 @@ class Renderer:
|
||||||
self.labels = {
|
self.labels = {
|
||||||
'waiting': pyglet.text.Label("Waiting for prediction"),
|
'waiting': pyglet.text.Label("Waiting for prediction"),
|
||||||
'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||||
'frame_time': pyglet.text.Label("t", x=120, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||||
'pred_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
|
'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||||
'frame_latency': pyglet.text.Label("", x=200, y=self.window.height - 17, color=info2_color, batch=self.batch_overlay),
|
'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||||
'pred_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
|
'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
|
||||||
'track_len': pyglet.text.Label("", x=500, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||||
|
'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||||
|
'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||||
'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
|
'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
|
||||||
'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
|
'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
|
||||||
}
|
}
|
||||||
|
@ -368,10 +381,15 @@ class Renderer:
|
||||||
self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
|
self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
|
||||||
self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
|
self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
|
||||||
|
|
||||||
|
if self.tracker_frame:
|
||||||
|
self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
|
||||||
|
self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
|
||||||
|
self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"
|
||||||
|
|
||||||
if self.prediction_frame:
|
if self.prediction_frame:
|
||||||
self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
|
self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
|
||||||
self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
|
self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
|
||||||
self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
|
# self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
|
||||||
|
|
||||||
|
|
||||||
# cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
# cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||||
|
@ -397,6 +415,7 @@ class Renderer:
|
||||||
|
|
||||||
|
|
||||||
def check_frames(self, dt):
|
def check_frames(self, dt):
|
||||||
|
new_tracks = False
|
||||||
try:
|
try:
|
||||||
self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
|
self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
|
||||||
if not self.first_time:
|
if not self.first_time:
|
||||||
|
@ -413,28 +432,41 @@ class Renderer:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||||
self.update_tracks()
|
new_tracks = True
|
||||||
except zmq.ZMQError as e:
|
except zmq.ZMQError as e:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
|
||||||
|
new_tracks = True
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if new_tracks:
|
||||||
|
self.update_tracks()
|
||||||
|
|
||||||
def update_tracks(self):
|
def update_tracks(self):
|
||||||
"""Updates the track objects and shapes. Called after setting `prediction_frame`
|
"""Updates the track objects and shapes. Called after setting `prediction_frame`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# clean up
|
# clean up
|
||||||
for track_id in list(self.drawn_tracks.keys()):
|
# for track_id in list(self.drawn_tracks.keys()):
|
||||||
if track_id not in self.prediction_frame.tracks.keys():
|
# if track_id not in self.prediction_frame.tracks.keys():
|
||||||
|
# # TODO fade out
|
||||||
|
# del self.drawn_tracks[track_id]
|
||||||
|
|
||||||
|
if self.prediction_frame:
|
||||||
|
for track_id, track in self.prediction_frame.tracks.items():
|
||||||
|
if track_id not in self.drawn_tracks:
|
||||||
|
self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
|
||||||
|
else:
|
||||||
|
self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
|
||||||
|
|
||||||
|
# clean up
|
||||||
|
for track in self.drawn_tracks.values():
|
||||||
|
if track.update_at < time.time() - 5:
|
||||||
# TODO fade out
|
# TODO fade out
|
||||||
del self.drawn_tracks[track_id]
|
del self.drawn_tracks[track_id]
|
||||||
|
|
||||||
|
|
||||||
for track_id, track in self.prediction_frame.tracks.items():
|
|
||||||
if track_id not in self.drawn_tracks:
|
|
||||||
self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
|
|
||||||
else:
|
|
||||||
self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
|
|
||||||
|
|
||||||
|
|
||||||
def on_key_press(self, symbol, modifiers):
|
def on_key_press(self, symbol, modifiers):
|
||||||
print('A key was pressed, use f to hide')
|
print('A key was pressed, use f to hide')
|
||||||
if symbol == ord('f'):
|
if symbol == ord('f'):
|
||||||
|
@ -550,6 +582,7 @@ class Renderer:
|
||||||
def run(self):
|
def run(self):
|
||||||
frame = None
|
frame = None
|
||||||
prediction_frame = None
|
prediction_frame = None
|
||||||
|
tracker_frame = None
|
||||||
|
|
||||||
i=0
|
i=0
|
||||||
first_time = None
|
first_time = None
|
||||||
|
@ -626,7 +659,7 @@ colorset = [(0, 0, 0),
|
||||||
(255, 255, 0)
|
(255, 255, 0)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Deprecated
|
||||||
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
|
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
|
||||||
# TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
|
# TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
|
||||||
# or https://github.com/pygobject/pycairo?tab=readme-ov-file
|
# or https://github.com/pygobject/pycairo?tab=readme-ov-file
|
||||||
|
|
Loading…
Reference in a new issue