From c4d498d5513a0d49942381b1f0947dbf58919b92 Mon Sep 17 00:00:00 2001
From: Ruben van de Ven
Date: Thu, 2 Jan 2025 16:24:59 +0100
Subject: [PATCH] Draw arrowed clusters instead of individual detections

---
 trap/config.py      |  3 +++
 trap/cv_renderer.py | 10 +++++-----
 trap/tools.py       |  9 +++++----
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/trap/config.py b/trap/config.py
index 87194b9..ac21f2e 100644
--- a/trap/config.py
+++ b/trap/config.py
@@ -353,6 +353,9 @@ render_parser.add_argument("--render-hide-stats",
 render_parser.add_argument("--full-screen",
                     help="Set Window full screen",
                     action='store_true')
+render_parser.add_argument("--render-clusters",
+                    help="Render arrowed clusters instead of individual predictions",
+                    action='store_true')
 render_parser.add_argument("--render-url",
                     help="""Stream renderer on given URL. Two easy approaches:
diff --git a/trap/cv_renderer.py b/trap/cv_renderer.py
index 26f98a4..bedc49e 100644
--- a/trap/cv_renderer.py
+++ b/trap/cv_renderer.py
@@ -392,7 +392,7 @@ class CvRenderer:
         if first_time is None:
             first_time = frame.time
 
-        img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions)
+        img = decorate_frame(frame, tracker_frame, prediction_frame, first_time, self.config, self.tracks, self.predictions, self.config.render_clusters)
 
         logger.debug(f"write frame {frame.time - first_time:.3f}s")
         if self.out_writer:
@@ -456,7 +456,7 @@ def get_animation_position(track: Track, current_frame: Frame):
 
 
 # Deprecated
-def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track]) -> np.array:
+def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace, tracks: Dict[str, Track], predictions: Dict[str, Track], as_clusters = True) -> np.array:
     # TODO: replace opencv with QPainter to support alpha? https://doc.qt.io/qtforpython-5/PySide2/QtGui/QPainter.html#PySide2.QtGui.PySide2.QtGui.QPainter.drawImage
     # or https://github.com/pygobject/pycairo?tab=readme-ov-file
     # or https://pyglet.readthedocs.io/en/latest/programming_guide/shapes.html
@@ -498,10 +498,10 @@ def decorate_frame(frame: Frame, tracker_frame: Frame, prediction_frame: Frame,
     else:
         for track_id, track in predictions.items():
             inv_H = np.linalg.pinv(prediction_frame.H)
-            # draw_track(img, track, int(track_id))
-            draw_trackjectron_history(img, track, int(track.track_id), convert_world_points_to_img_points)
+            # For debugging:
+            # draw_trackjectron_history(img, track, int(track.track_id), convert_world_points_to_img_points)
             anim_position = get_animation_position(track, frame)
-            draw_track_predictions(img, track, int(track.track_id)+1, config.camera, convert_world_points_to_img_points, anim_position=anim_position, as_clusters=True)
+            draw_track_predictions(img, track, int(track.track_id)+1, config.camera, convert_world_points_to_img_points, anim_position=anim_position, as_clusters=as_clusters)
             cv2.putText(img, f"{len(track.predictor_history) if track.predictor_history else 'none'}", to_point(track.history[0].get_foot_coords()), cv2.FONT_HERSHEY_COMPLEX, 1, (255,255,255), 1)
         if prediction_frame.maps:
             for i, m in enumerate(prediction_frame.maps):
diff --git a/trap/tools.py b/trap/tools.py
index 902b9f4..315efec 100644
--- a/trap/tools.py
+++ b/trap/tools.py
@@ -362,11 +362,12 @@ def draw_track_predictions(img: cv2.Mat, track: Track, color_index: int, camera:
         # cv2 only draws to integer coordinates
         points = np.rint(points).astype(int)
         thickness = max(1, int(cluster.probability * 6))
-        if len(cluster.next_point_clusters) == 1:
+        thickness = 1
+        # if len(cluster.next_point_clusters) == 1:
             # not a final point, nor a split:
-            cv2.line(img, points[0], points[1], color, thickness, lineType=cv2.LINE_AA)
-        else:
-            cv2.arrowedLine(img, points[0], points[1], color, thickness, cv2.LINE_AA)
+        cv2.line(img, points[0], points[1], color, thickness, lineType=cv2.LINE_AA)
+        # else:
+        #     cv2.arrowedLine(img, points[0], points[1], color, thickness, cv2.LINE_AA)
 
         for sub in cluster.next_point_clusters:
             draw_cluster(img, sub)
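
Note (sketch, not part of the patch): for context, below is a minimal, self-contained approximation of the cluster drawing that draw_track_predictions performs, rendering a prediction as a tree of point clusters where segments get an arrowhead at a split and thickness scales with cluster probability. The PointCluster class and the sample data are hypothetical stand-ins for the real structures in trap/tools.py; only the cv2.line/cv2.arrowedLine calls mirror the code touched above, and the patch itself further simplifies the real code to unconditional plain lines of thickness 1, leaving the arrow branch commented out.

    from dataclasses import dataclass, field
    from typing import List, Tuple

    import cv2
    import numpy as np


    @dataclass
    class PointCluster:
        """Hypothetical stand-in for the cluster nodes used by draw_track_predictions."""
        point: Tuple[int, int]   # image coordinates of this cluster
        probability: float       # probability mass reaching this cluster
        next_point_clusters: List["PointCluster"] = field(default_factory=list)


    def draw_cluster(img: np.ndarray, cluster: PointCluster, color=(255, 255, 255)) -> None:
        """Draw each outgoing segment; mark segments with an arrowhead where the path splits."""
        for nxt in cluster.next_point_clusters:
            thickness = max(1, int(nxt.probability * 6))
            if len(cluster.next_point_clusters) == 1:
                # Not a final point, nor a split: a plain anti-aliased segment.
                cv2.line(img, cluster.point, nxt.point, color, thickness, lineType=cv2.LINE_AA)
            else:
                # A split: arrow each branch.
                cv2.arrowedLine(img, cluster.point, nxt.point, color, thickness, cv2.LINE_AA)
            draw_cluster(img, nxt, color)


    if __name__ == "__main__":
        img = np.zeros((240, 320, 3), dtype=np.uint8)
        # One track that runs straight, then splits into two branches
        # of unequal probability (0.7 vs 0.3).
        root = PointCluster((20, 120), 1.0, [
            PointCluster((120, 120), 1.0, [
                PointCluster((220, 60), 0.7),
                PointCluster((220, 180), 0.3),
            ]),
        ])
        draw_cluster(img, root)
        cv2.imwrite("clusters.png", img)

With the patch applied, cluster rendering is opt-in: decorate_frame keeps as_clusters = True as its default, but CvRenderer now passes self.config.render_clusters, which store_true leaves at False unless --render-clusters is given on the command line.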