Split tracker in renderer

Ruben van de Ven 2024-09-30 15:42:06 +02:00
parent 6e98138cc2
commit c3263e7448


@@ -58,6 +58,8 @@ def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:

 class DrawnTrack:
     def __init__(self, track_id, track: Track, renderer: Renderer, H):
+        # self.created_at = time.time()
+        self.update_at = self.created_at = time.time()
         self.track_id = track_id
         self.renderer = renderer
         self.set_track(track, H)
@@ -67,6 +69,8 @@ class DrawnTrack:
         self.pred_shapes: list[list[pyglet.shapes.Line]] = []

     def set_track(self, track: Track, H):
+        self.update_at = time.time()
+
         self.track = track
         self.H = H
         self.coords = [d.get_foot_coords() for d in track.history]
@@ -239,6 +243,11 @@ class Renderer:
         self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
         self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)

+        self.tracker_sock = context.socket(zmq.SUB)
+        self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
+        self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.tracker_sock.connect(config.zmq_trajectory_addr)
+
         self.frame_sock = context.socket(zmq.SUB)
         self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
         self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
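The new tracker_sock uses the same conflated-subscriber pattern as the other sockets: zmq.CONFLATE keeps only the most recent message, and it must be set before connect(), otherwise ZeroMQ silently ignores it. A minimal standalone sketch of that pattern, assuming pyzmq and a placeholder address rather than the project's config.zmq_trajectory_addr:

import zmq

context = zmq.Context()
sock = context.socket(zmq.SUB)
sock.setsockopt(zmq.CONFLATE, 1)     # keep only the latest message; must come BEFORE connect
sock.setsockopt(zmq.SUBSCRIBE, b'')  # subscribe to all topics
sock.connect("tcp://localhost:5556")  # placeholder address

try:
    frame = sock.recv_pyobj(zmq.NOBLOCK)  # non-blocking read, as in check_frames()
except zmq.ZMQError:
    frame = None  # nothing new this tick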
@@ -282,6 +291,7 @@ class Renderer:
         self.first_time: float|None = None
         self.frame: Frame|None= None
+        self.tracker_frame: Frame|None = None
         self.prediction_frame: Frame|None = None
@@ -341,8 +351,9 @@ class Renderer:

     def init_labels(self):
         base_color = (255,)*4
-        info_color = (255,255,0, 255)
-        info2_color = (255,0, 255, 255)
+        color_predictor = (255,255,0, 255)
+        color_info = (255,0, 255, 255)
+        color_tracker = (0,255, 255, 255)

         options = []
         for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
@@ -351,11 +362,13 @@ class Renderer:
         self.labels = {
             'waiting': pyglet.text.Label("Waiting for prediction"),
             'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
-            'frame_time': pyglet.text.Label("t", x=120, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
-            'pred_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
-            'frame_latency': pyglet.text.Label("", x=200, y=self.window.height - 17, color=info2_color, batch=self.batch_overlay),
-            'pred_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
-            'track_len': pyglet.text.Label("", x=500, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
+            'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
+            'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
+            'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
+            'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
+            'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
+            'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
+            'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
             'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
             'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
         }
@@ -368,10 +381,15 @@ class Renderer:
             self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
             self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"

+        if self.tracker_frame:
+            self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
+            self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
+            self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"
+
         if self.prediction_frame:
             self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
             self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
-            self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
+            # self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"

         # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
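Note that tracker_time and pred_time, like the existing frame_latency, format the frame timestamp minus time.time(), so the on-screen value is negative: the age of that frame in seconds at render time. A tiny illustration, assuming Frame.time is an epoch timestamp as used above (the 42 ms offset is a made-up value):

import time

tracker_time = time.time() - 0.042           # a frame stamped 42 ms ago
print(f"{tracker_time - time.time():.3f}s")  # prints approximately "-0.042s"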
@@ -397,6 +415,7 @@ class Renderer:

     def check_frames(self, dt):
+        new_tracks = False
         try:
             self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
             if not self.first_time:
@@ -413,27 +432,40 @@ class Renderer:
             pass

         try:
             self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
-            self.update_tracks()
+            new_tracks = True
         except zmq.ZMQError as e:
             pass

+        try:
+            self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
+            new_tracks = True
+        except zmq.ZMQError as e:
+            pass
+
+        if new_tracks:
+            self.update_tracks()

     def update_tracks(self):
         """Updates the track objects and shapes. Called after setting `prediction_frame`
         """
         # clean up
-        for track_id in list(self.drawn_tracks.keys()):
-            if track_id not in self.prediction_frame.tracks.keys():
-                # TODO fade out
-                del self.drawn_tracks[track_id]
+        # for track_id in list(self.drawn_tracks.keys()):
+        #     if track_id not in self.prediction_frame.tracks.keys():
+        #         # TODO fade out
+        #         del self.drawn_tracks[track_id]

-        for track_id, track in self.prediction_frame.tracks.items():
-            if track_id not in self.drawn_tracks:
-                self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
-            else:
-                self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
+        if self.prediction_frame:
+            for track_id, track in self.prediction_frame.tracks.items():
+                if track_id not in self.drawn_tracks:
+                    self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
+                else:
+                    self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
+
+        # clean up
+        for track in self.drawn_tracks.values():
+            if track.update_at < time.time() - 5:
+                # TODO fade out
+                del self.drawn_tracks[track_id]

     def on_key_press(self, symbol, modifiers):
         print('A key was pressed, use f to hide')
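One caveat in the new clean-up loop above: it deletes entries from self.drawn_tracks while iterating over its values, and the del uses track_id, which at that point still holds whatever the preceding prediction loop last assigned. A safer variant, sketched here only and not part of this commit, collects the stale ids first:

# Hypothetical stale-track clean-up for update_tracks(), not part of this commit:
# gather ids first so the dict is not mutated during iteration, and delete by each
# track's own id rather than the leftover loop variable.
stale_ids = [track_id for track_id, track in self.drawn_tracks.items()
             if track.update_at < time.time() - 5]
for track_id in stale_ids:
    # TODO fade out
    del self.drawn_tracks[track_id]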
@@ -550,6 +582,7 @@ class Renderer:

     def run(self):
         frame = None
         prediction_frame = None
+        tracker_frame = None
         i=0
         first_time = None
@@ -626,7 +659,7 @@ colorset = [(0, 0, 0),
             (255, 255, 0)
 ]

+# Deprecated
 def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
     # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
     # or https://github.com/pygobject/pycairo?tab=readme-ov-file