Split tracker in renderer
This commit is contained in:
parent
6e98138cc2
commit
c3263e7448
1 changed files with 53 additions and 20 deletions
|
@ -58,6 +58,8 @@ def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:
|
|||
|
||||
class DrawnTrack:
|
||||
def __init__(self, track_id, track: Track, renderer: Renderer, H):
|
||||
# self.created_at = time.time()
|
||||
self.update_at = self.created_at = time.time()
|
||||
self.track_id = track_id
|
||||
self.renderer = renderer
|
||||
self.set_track(track, H)
|
||||
|
@ -67,6 +69,8 @@ class DrawnTrack:
|
|||
self.pred_shapes: list[list[pyglet.shapes.Line]] = []
|
||||
|
||||
def set_track(self, track: Track, H):
|
||||
self.update_at = time.time()
|
||||
|
||||
self.track = track
|
||||
self.H = H
|
||||
self.coords = [d.get_foot_coords() for d in track.history]
|
||||
|
@ -239,6 +243,11 @@ class Renderer:
|
|||
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
|
||||
|
||||
self.tracker_sock = context.socket(zmq.SUB)
|
||||
self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
self.tracker_sock.connect(config.zmq_trajectory_addr)
|
||||
|
||||
self.frame_sock = context.socket(zmq.SUB)
|
||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||
|
@ -282,6 +291,7 @@ class Renderer:
|
|||
|
||||
self.first_time: float|None = None
|
||||
self.frame: Frame|None= None
|
||||
self.tracker_frame: Frame|None = None
|
||||
self.prediction_frame: Frame|None = None
|
||||
|
||||
|
||||
|
@ -341,8 +351,9 @@ class Renderer:
|
|||
|
||||
def init_labels(self):
|
||||
base_color = (255,)*4
|
||||
info_color = (255,255,0, 255)
|
||||
info2_color = (255,0, 255, 255)
|
||||
color_predictor = (255,255,0, 255)
|
||||
color_info = (255,0, 255, 255)
|
||||
color_tracker = (0,255, 255, 255)
|
||||
|
||||
options = []
|
||||
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
|
||||
|
@ -351,11 +362,13 @@ class Renderer:
|
|||
self.labels = {
|
||||
'waiting': pyglet.text.Label("Waiting for prediction"),
|
||||
'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'frame_time': pyglet.text.Label("t", x=120, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'pred_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
|
||||
'frame_latency': pyglet.text.Label("", x=200, y=self.window.height - 17, color=info2_color, batch=self.batch_overlay),
|
||||
'pred_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
|
||||
'track_len': pyglet.text.Label("", x=500, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||
'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
|
||||
'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
|
||||
'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
|
||||
'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
|
||||
'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
|
||||
'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
|
||||
}
|
||||
|
@ -368,10 +381,15 @@ class Renderer:
|
|||
self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
|
||||
self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
|
||||
|
||||
if self.tracker_frame:
|
||||
self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
|
||||
self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
|
||||
self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"
|
||||
|
||||
if self.prediction_frame:
|
||||
self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
|
||||
self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
|
||||
self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
|
||||
# self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
|
||||
|
||||
|
||||
# cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
|
||||
|
@ -397,6 +415,7 @@ class Renderer:
|
|||
|
||||
|
||||
def check_frames(self, dt):
|
||||
new_tracks = False
|
||||
try:
|
||||
self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
if not self.first_time:
|
||||
|
@ -413,28 +432,41 @@ class Renderer:
|
|||
pass
|
||||
try:
|
||||
self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
self.update_tracks()
|
||||
new_tracks = True
|
||||
except zmq.ZMQError as e:
|
||||
pass
|
||||
try:
|
||||
self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
|
||||
new_tracks = True
|
||||
except zmq.ZMQError as e:
|
||||
pass
|
||||
|
||||
if new_tracks:
|
||||
self.update_tracks()
|
||||
|
||||
def update_tracks(self):
|
||||
"""Updates the track objects and shapes. Called after setting `prediction_frame`
|
||||
"""
|
||||
|
||||
# clean up
|
||||
for track_id in list(self.drawn_tracks.keys()):
|
||||
if track_id not in self.prediction_frame.tracks.keys():
|
||||
# for track_id in list(self.drawn_tracks.keys()):
|
||||
# if track_id not in self.prediction_frame.tracks.keys():
|
||||
# # TODO fade out
|
||||
# del self.drawn_tracks[track_id]
|
||||
|
||||
if self.prediction_frame:
|
||||
for track_id, track in self.prediction_frame.tracks.items():
|
||||
if track_id not in self.drawn_tracks:
|
||||
self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
|
||||
else:
|
||||
self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
|
||||
|
||||
# clean up
|
||||
for track in self.drawn_tracks.values():
|
||||
if track.update_at < time.time() - 5:
|
||||
# TODO fade out
|
||||
del self.drawn_tracks[track_id]
|
||||
|
||||
|
||||
for track_id, track in self.prediction_frame.tracks.items():
|
||||
if track_id not in self.drawn_tracks:
|
||||
self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
|
||||
else:
|
||||
self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
|
||||
|
||||
|
||||
def on_key_press(self, symbol, modifiers):
|
||||
print('A key was pressed, use f to hide')
|
||||
if symbol == ord('f'):
|
||||
|
@ -550,6 +582,7 @@ class Renderer:
|
|||
def run(self):
|
||||
frame = None
|
||||
prediction_frame = None
|
||||
tracker_frame = None
|
||||
|
||||
i=0
|
||||
first_time = None
|
||||
|
@ -626,7 +659,7 @@ colorset = [(0, 0, 0),
|
|||
(255, 255, 0)
|
||||
]
|
||||
|
||||
|
||||
# Deprecated
|
||||
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
|
||||
# TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
|
||||
# or https://github.com/pygobject/pycairo?tab=readme-ov-file
|
||||
|
|
Loading…
Reference in a new issue