From c3263e74489901aff03eb17abfe1cd8c48b066d7 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Mon, 30 Sep 2024 15:42:06 +0200 Subject: [PATCH] Split tracker in renderer --- trap/renderer.py | 73 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/trap/renderer.py b/trap/renderer.py index 04e0a06..ce3348e 100644 --- a/trap/renderer.py +++ b/trap/renderer.py @@ -58,6 +58,8 @@ def relativePolarToPoint(origin, r, angle) -> tuple[float, float]: class DrawnTrack: def __init__(self, track_id, track: Track, renderer: Renderer, H): + # self.created_at = time.time() + self.update_at = self.created_at = time.time() self.track_id = track_id self.renderer = renderer self.set_track(track, H) @@ -67,6 +69,8 @@ class DrawnTrack: self.pred_shapes: list[list[pyglet.shapes.Line]] = [] def set_track(self, track: Track, H): + self.update_at = time.time() + self.track = track self.H = H self.coords = [d.get_foot_coords() for d in track.history] @@ -239,6 +243,11 @@ class Renderer: self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'') self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr) + self.tracker_sock = context.socket(zmq.SUB) + self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! + self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'') + self.tracker_sock.connect(config.zmq_trajectory_addr) + self.frame_sock = context.socket(zmq.SUB) self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'') @@ -282,6 +291,7 @@ class Renderer: self.first_time: float|None = None self.frame: Frame|None= None + self.tracker_frame: Frame|None = None self.prediction_frame: Frame|None = None @@ -341,8 +351,9 @@ class Renderer: def init_labels(self): base_color = (255,)*4 - info_color = (255,255,0, 255) - info2_color = (255,0, 255, 255) + color_predictor = (255,255,0, 255) + color_info = (255,0, 255, 255) + color_tracker = (0,255, 255, 255) options = [] for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']: @@ -351,11 +362,13 @@ class Renderer: self.labels = { 'waiting': pyglet.text.Label("Waiting for prediction"), 'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay), - 'frame_time': pyglet.text.Label("t", x=120, y=self.window.height - 17, color=base_color, batch=self.batch_overlay), - 'pred_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=info_color, batch=self.batch_overlay), - 'frame_latency': pyglet.text.Label("", x=200, y=self.window.height - 17, color=info2_color, batch=self.batch_overlay), - 'pred_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=info_color, batch=self.batch_overlay), - 'track_len': pyglet.text.Label("", x=500, y=self.window.height - 17, color=base_color, batch=self.batch_overlay), + 'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay), + 'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay), + 'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay), + 'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay), + 'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay), + 'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay), + 'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay), 'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay), 'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay), } @@ -368,10 +381,15 @@ class Renderer: self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s" self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s" + if self.tracker_frame: + self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}" + self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s" + self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks" + if self.prediction_frame: self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}" self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s" - self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks" + # self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks" # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1) @@ -397,6 +415,7 @@ class Renderer: def check_frames(self, dt): + new_tracks = False try: self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK) if not self.first_time: @@ -413,27 +432,40 @@ class Renderer: pass try: self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK) - self.update_tracks() + new_tracks = True except zmq.ZMQError as e: pass + try: + self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK) + new_tracks = True + except zmq.ZMQError as e: + pass + + if new_tracks: + self.update_tracks() def update_tracks(self): """Updates the track objects and shapes. Called after setting `prediction_frame` """ # clean up - for track_id in list(self.drawn_tracks.keys()): - if track_id not in self.prediction_frame.tracks.keys(): + # for track_id in list(self.drawn_tracks.keys()): + # if track_id not in self.prediction_frame.tracks.keys(): + # # TODO fade out + # del self.drawn_tracks[track_id] + + if self.prediction_frame: + for track_id, track in self.prediction_frame.tracks.items(): + if track_id not in self.drawn_tracks: + self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H) + else: + self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H) + + # clean up + for track in self.drawn_tracks.values(): + if track.update_at < time.time() - 5: # TODO fade out del self.drawn_tracks[track_id] - - - for track_id, track in self.prediction_frame.tracks.items(): - if track_id not in self.drawn_tracks: - self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H) - else: - self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H) - def on_key_press(self, symbol, modifiers): print('A key was pressed, use f to hide') @@ -550,6 +582,7 @@ class Renderer: def run(self): frame = None prediction_frame = None + tracker_frame = None i=0 first_time = None @@ -626,7 +659,7 @@ colorset = [(0, 0, 0), (255, 255, 0) ] - +# Deprecated def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array: # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage # or https://github.com/pygobject/pycairo?tab=readme-ov-file