Split tracker in renderer
This commit is contained in:
		
							parent
							
								
									6e98138cc2
								
							
						
					
					
						commit
						c3263e7448
					
				
					 1 changed files with 53 additions and 20 deletions
				
			
		| 
						 | 
					@ -58,6 +58,8 @@ def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DrawnTrack:
 | 
					class DrawnTrack:
 | 
				
			||||||
    def __init__(self, track_id, track: Track, renderer: Renderer, H):
 | 
					    def __init__(self, track_id, track: Track, renderer: Renderer, H):
 | 
				
			||||||
 | 
					        # self.created_at = time.time()
 | 
				
			||||||
 | 
					        self.update_at = self.created_at = time.time()
 | 
				
			||||||
        self.track_id = track_id
 | 
					        self.track_id = track_id
 | 
				
			||||||
        self.renderer = renderer
 | 
					        self.renderer = renderer
 | 
				
			||||||
        self.set_track(track, H)
 | 
					        self.set_track(track, H)
 | 
				
			||||||
| 
						 | 
					@ -67,6 +69,8 @@ class DrawnTrack:
 | 
				
			||||||
        self.pred_shapes: list[list[pyglet.shapes.Line]] = []
 | 
					        self.pred_shapes: list[list[pyglet.shapes.Line]] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_track(self, track: Track, H):
 | 
					    def set_track(self, track: Track, H):
 | 
				
			||||||
 | 
					        self.update_at = time.time()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.track = track
 | 
					        self.track = track
 | 
				
			||||||
        self.H = H
 | 
					        self.H = H
 | 
				
			||||||
        self.coords = [d.get_foot_coords() for d in track.history]
 | 
					        self.coords = [d.get_foot_coords() for d in track.history]
 | 
				
			||||||
| 
						 | 
					@ -239,6 +243,11 @@ class Renderer:
 | 
				
			||||||
        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
					        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
        self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
 | 
					        self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        self.tracker_sock = context.socket(zmq.SUB)
 | 
				
			||||||
 | 
					        self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
				
			||||||
 | 
					        self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
 | 
					        self.tracker_sock.connect(config.zmq_trajectory_addr)
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        self.frame_sock = context.socket(zmq.SUB)
 | 
					        self.frame_sock = context.socket(zmq.SUB)
 | 
				
			||||||
        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
					        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
 | 
				
			||||||
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
					        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
 | 
				
			||||||
| 
						 | 
					@ -282,6 +291,7 @@ class Renderer:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.first_time: float|None = None
 | 
					        self.first_time: float|None = None
 | 
				
			||||||
        self.frame: Frame|None= None
 | 
					        self.frame: Frame|None= None
 | 
				
			||||||
 | 
					        self.tracker_frame: Frame|None = None
 | 
				
			||||||
        self.prediction_frame: Frame|None = None
 | 
					        self.prediction_frame: Frame|None = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -341,8 +351,9 @@ class Renderer:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_labels(self):
 | 
					    def init_labels(self):
 | 
				
			||||||
        base_color = (255,)*4
 | 
					        base_color = (255,)*4
 | 
				
			||||||
        info_color = (255,255,0, 255)
 | 
					        color_predictor = (255,255,0, 255)
 | 
				
			||||||
        info2_color = (255,0, 255, 255)
 | 
					        color_info = (255,0, 255, 255)
 | 
				
			||||||
 | 
					        color_tracker = (0,255, 255, 255)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        options = []
 | 
					        options = []
 | 
				
			||||||
        for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
 | 
					        for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
 | 
				
			||||||
| 
						 | 
					@ -351,11 +362,13 @@ class Renderer:
 | 
				
			||||||
        self.labels = {
 | 
					        self.labels = {
 | 
				
			||||||
            'waiting': pyglet.text.Label("Waiting for prediction"),
 | 
					            'waiting': pyglet.text.Label("Waiting for prediction"),
 | 
				
			||||||
            'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
 | 
					            'frame_idx': pyglet.text.Label("", x=20, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
 | 
				
			||||||
            'frame_time': pyglet.text.Label("t", x=120, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
 | 
					            'tracker_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
 | 
				
			||||||
            'pred_idx': pyglet.text.Label("", x=90, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
 | 
					            'pred_idx': pyglet.text.Label("", x=110, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
 | 
				
			||||||
            'frame_latency': pyglet.text.Label("", x=200, y=self.window.height - 17, color=info2_color, batch=self.batch_overlay),
 | 
					            'frame_time': pyglet.text.Label("t", x=140, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
 | 
				
			||||||
            'pred_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=info_color, batch=self.batch_overlay),
 | 
					            'frame_latency': pyglet.text.Label("", x=235, y=self.window.height - 17, color=color_info, batch=self.batch_overlay),
 | 
				
			||||||
            'track_len': pyglet.text.Label("", x=500, y=self.window.height - 17, color=base_color, batch=self.batch_overlay),
 | 
					            'tracker_time': pyglet.text.Label("", x=300, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
 | 
				
			||||||
 | 
					            'pred_time': pyglet.text.Label("", x=360, y=self.window.height - 17, color=color_predictor, batch=self.batch_overlay),
 | 
				
			||||||
 | 
					            'track_len': pyglet.text.Label("", x=800, y=self.window.height - 17, color=color_tracker, batch=self.batch_overlay),
 | 
				
			||||||
            'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
 | 
					            'options1': pyglet.text.Label(options.pop(-1), x=20, y=30, color=base_color, batch=self.batch_overlay),
 | 
				
			||||||
            'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
 | 
					            'options2': pyglet.text.Label(" | ".join(options), x=20, y=10, color=base_color, batch=self.batch_overlay),
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
| 
						 | 
					@ -368,10 +381,15 @@ class Renderer:
 | 
				
			||||||
            self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
 | 
					            self.labels['frame_time'].text = f"{self.frame.time - self.first_time: >10.2f}s"
 | 
				
			||||||
            self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
 | 
					            self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        if self.tracker_frame:
 | 
				
			||||||
 | 
					            self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
 | 
				
			||||||
 | 
					            self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
 | 
				
			||||||
 | 
					            self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
        if self.prediction_frame:
 | 
					        if self.prediction_frame:
 | 
				
			||||||
            self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
 | 
					            self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
 | 
				
			||||||
            self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
 | 
					            self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
 | 
				
			||||||
            self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
 | 
					            # self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
 | 
					        # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
 | 
				
			||||||
| 
						 | 
					@ -397,6 +415,7 @@ class Renderer:
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_frames(self, dt):
 | 
					    def check_frames(self, dt):
 | 
				
			||||||
 | 
					        new_tracks = False
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
 | 
					            self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
 | 
				
			||||||
            if not self.first_time:
 | 
					            if not self.first_time:
 | 
				
			||||||
| 
						 | 
					@ -413,27 +432,40 @@ class Renderer:
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
 | 
					            self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
 | 
				
			||||||
            self.update_tracks()
 | 
					            new_tracks = True
 | 
				
			||||||
        except zmq.ZMQError as e:
 | 
					        except zmq.ZMQError as e:
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
 | 
				
			||||||
 | 
					            new_tracks = True
 | 
				
			||||||
 | 
					        except zmq.ZMQError as e:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        if new_tracks:
 | 
				
			||||||
 | 
					            self.update_tracks()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update_tracks(self):
 | 
					    def update_tracks(self):
 | 
				
			||||||
        """Updates the track objects and shapes. Called after setting `prediction_frame`
 | 
					        """Updates the track objects and shapes. Called after setting `prediction_frame`
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        # clean up
 | 
					        # clean up
 | 
				
			||||||
        for track_id in list(self.drawn_tracks.keys()):
 | 
					        # for track_id in list(self.drawn_tracks.keys()):
 | 
				
			||||||
            if track_id not in self.prediction_frame.tracks.keys():
 | 
					        #     if track_id not in self.prediction_frame.tracks.keys():
 | 
				
			||||||
                # TODO fade out
 | 
					        #         # TODO fade out
 | 
				
			||||||
                del self.drawn_tracks[track_id]
 | 
					        #         del self.drawn_tracks[track_id]
 | 
				
			||||||
        
 | 
					 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        if self.prediction_frame:
 | 
				
			||||||
            for track_id, track in self.prediction_frame.tracks.items():
 | 
					            for track_id, track in self.prediction_frame.tracks.items():
 | 
				
			||||||
                if track_id not in self.drawn_tracks:
 | 
					                if track_id not in self.drawn_tracks:
 | 
				
			||||||
                    self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
 | 
					                    self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H)
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
 | 
					                    self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					        # clean up
 | 
				
			||||||
 | 
					        for track in self.drawn_tracks.values():
 | 
				
			||||||
 | 
					            if track.update_at < time.time() - 5:
 | 
				
			||||||
 | 
					                # TODO fade out
 | 
				
			||||||
 | 
					                del self.drawn_tracks[track_id]
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def on_key_press(self, symbol, modifiers):
 | 
					    def on_key_press(self, symbol, modifiers):
 | 
				
			||||||
        print('A key was pressed, use f to hide')
 | 
					        print('A key was pressed, use f to hide')
 | 
				
			||||||
| 
						 | 
					@ -550,6 +582,7 @@ class Renderer:
 | 
				
			||||||
    def run(self):
 | 
					    def run(self):
 | 
				
			||||||
        frame = None
 | 
					        frame = None
 | 
				
			||||||
        prediction_frame = None
 | 
					        prediction_frame = None
 | 
				
			||||||
 | 
					        tracker_frame = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        i=0
 | 
					        i=0
 | 
				
			||||||
        first_time = None
 | 
					        first_time = None
 | 
				
			||||||
| 
						 | 
					@ -626,7 +659,7 @@ colorset = [(0, 0, 0),
 | 
				
			||||||
 (255, 255, 0)
 | 
					 (255, 255, 0)
 | 
				
			||||||
 ]
 | 
					 ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Deprecated
 | 
				
			||||||
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
 | 
					def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
 | 
				
			||||||
    # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
 | 
					    # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
 | 
				
			||||||
    # or https://github.com/pygobject/pycairo?tab=readme-ov-file
 | 
					    # or https://github.com/pygobject/pycairo?tab=readme-ov-file
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue