diff --git a/trap/animation_renderer.py b/trap/animation_renderer.py
new file mode 100644
index 0000000..43469a6
--- /dev/null
+++ b/trap/animation_renderer.py
@@ -0,0 +1,427 @@
+# used for "Forward Referencing of type annotations"
+from __future__ import annotations
+
+import time
+import ffmpeg
+from argparse import Namespace
+import datetime
+import logging
+from multiprocessing import Event
+from multiprocessing.synchronize import Event as BaseEvent
+import cv2
+import numpy as np
+
+import pyglet
+import pyglet.event
+import zmq
+import tempfile
+from pathlib import Path
+import shutil
+import math
+
+from pyglet import shapes
+from PIL import Image
+
+from trap.frame_emitter import DetectionState, Frame, Track
+from trap.preview_renderer import DrawnTrack, PROJECTION_IMG, PROJECTION_MAP
+
+
+logger = logging.getLogger("trap.renderer")
+
+class AnimationRenderer:
+    def __init__(self, config: Namespace, is_running: BaseEvent):
+        self.config = config
+        self.is_running = is_running
+
+        context = zmq.Context()
+        self.prediction_sock = context.socket(zmq.SUB)
+        self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
+        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.prediction_sock.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
+
+        self.tracker_sock = context.socket(zmq.SUB)
+        self.tracker_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
+        self.tracker_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.tracker_sock.connect(config.zmq_trajectory_addr)
+
+        self.frame_sock = context.socket(zmq.SUB)
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
+        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.frame_sock.connect(config.zmq_frame_addr)
+
+        self.H = self.config.H
+
+        self.inv_H = np.linalg.pinv(self.H)
+
+        # TODO: get FPS from frame_emitter
+        # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
+        self.fps = 60
+        self.frame_size = (self.config.frame_width,self.config.frame_height)
+        self.hide_stats = False
+        self.out_writer = None # self.start_writer() if self.config.render_file else None
+        self.streaming_process = None # self.start_streaming() if self.config.render_url else None
+
+        if self.config.render_window:
+            pass
+            # cv2.namedWindow("frame", cv2.WND_PROP_FULLSCREEN)
+            # cv2.setWindowProperty("frame",cv2.WND_PROP_FULLSCREEN,cv2.WINDOW_FULLSCREEN)
+        else:
+            pyglet.options["headless"] = True
+
+        config = pyglet.gl.Config(sample_buffers=1, samples=4)
+        # , fullscreen=self.config.render_window
+        self.window = pyglet.window.Window(width=self.frame_size[0], height=self.frame_size[1], config=config, fullscreen=self.config.full_screen)
+        self.window.set_handler('on_draw', self.on_draw)
+        self.window.set_handler('on_refresh', self.on_refresh)
+        self.window.set_handler('on_close', self.on_close)
+
+        pyglet.gl.glClearColor(0,0,0, 0)
+        self.fps_display = pyglet.window.FPSDisplay(window=self.window, color=(255,255,255,255))
+        self.fps_display.label.x = self.window.width - 50
+        self.fps_display.label.y = self.window.height - 17
+        self.fps_display.label.bold = False
+        self.fps_display.label.font_size = 10
+
+        self.drawn_tracks: dict[str, DrawnTrack] = {}
+
+
+        self.first_time: float|None = None
+        self.frame: Frame|None= None
+        self.tracker_frame: Frame|None = None
+        self.prediction_frame: Frame|None = None
+
+
+        self.batch_bg = pyglet.graphics.Batch()
+        self.batch_overlay = pyglet.graphics.Batch()
+        self.batch_anim = pyglet.graphics.Batch()
+
+        self.init_shapes()
+
+        self.init_labels()
+
+
+    def init_shapes(self):
+        '''
+        Due to error when running headless, we need to configure options before extending the shapes class
+        '''
+        class GradientLine(shapes.Line):
+            def __init__(self, x, y, x2, y2, width=1, color1=[255,255,255], color2=[255,255,255], batch=None, group=None):
+                # print('colors!', colors)
+                # assert len(colors) == 6
+
+                r, g, b, *a = color1
+                self._rgba1 = (r, g, b, a[0] if a else 255)
+                r, g, b, *a = color2
+                self._rgba2 = (r, g, b, a[0] if a else 255)
+
+                # print('rgba', self._rgba)
+
+                super().__init__(x, y, x2, y2, width, color1, batch=None, group=None)
+                # 10.2f}s"
+        self.labels['frame_latency'].text = f"{self.frame.time - time.time():.2f}s"
+
+        if self.tracker_frame:
+            self.labels['tracker_idx'].text = f"{self.tracker_frame.index - self.frame.index}"
+            self.labels['tracker_time'].text = f"{self.tracker_frame.time - time.time():.3f}s"
+            self.labels['track_len'].text = f"{len(self.tracker_frame.tracks)} tracks"
+
+        if self.prediction_frame:
+            self.labels['pred_idx'].text = f"{self.prediction_frame.index - self.frame.index}"
+            self.labels['pred_time'].text = f"{self.prediction_frame.time - time.time():.3f}s"
+            # self.labels['track_len'].text = f"{len(self.prediction_frame.tracks)} tracks"
+
+
+        # cv2.putText(img, f"{frame.index:06d}", (20,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
+        # cv2.putText(img, f"{frame.time - first_time:.3f}s", (120,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
+
+        # if prediction_frame:
+        #     # render Δt and Δ frames
+        #     cv2.putText(img, f"{prediction_frame.index - frame.index}", (90,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
+        #     cv2.putText(img, f"{prediction_frame.time - time.time():.2f}s", (200,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
+        #     cv2.putText(img, f"{len(prediction_frame.tracks)} tracks", (500,17), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
+        #     cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()]):.2f}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
+        #     cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()]):.2f}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
+        #     cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()]):.2f}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
+
+        # options = []
+        # for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
+        #     options.append(f"{option}: {config.__dict__[option]}")
+
+
+        # cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
+        # cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
+
+
+
+    def check_frames(self, dt):
+        new_tracks = False
+        try:
+            self.frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
+            if not self.first_time:
+                self.first_time = self.frame.time
+            img = self.frame.img
+            img = cv2.warpPerspective(img, self.H, (self.frame.img.shape[1], self.frame.img.shape[0]))
+            img = cv2.GaussianBlur(img, (15, 15), 0)
+            img = cv2.flip(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), 0)
+            img = pyglet.image.ImageData(self.frame_size[0], self.frame_size[1], 'RGB', img.tobytes())
+            # don't draw in batch, so that it is the background
+            self.video_sprite = pyglet.sprite.Sprite(img=img, batch=self.batch_bg)
+            self.video_sprite.opacity = 100
+        except zmq.ZMQError as e:
+            # idx = frame.index if frame else "NONE"
+            # logger.debug(f"reuse video frame {idx}")
+            pass
+        try:
+            self.prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
+            new_tracks = True
+        except zmq.ZMQError as e:
+            pass
+        try:
+            self.tracker_frame: Frame = self.tracker_sock.recv_pyobj(zmq.NOBLOCK)
+            new_tracks = True
+        except zmq.ZMQError as e:
+            pass
+
+        if new_tracks:
+            self.update_tracks()
+
+    def update_tracks(self):
+        """Updates the track objects and shapes. Called after setting `prediction_frame`
+        """
+
+        # clean up
+        # for track_id in list(self.drawn_tracks.keys()):
+        #     if track_id not in self.prediction_frame.tracks.keys():
+        #         # TODO fade out
+        #         del self.drawn_tracks[track_id]
+
+        if self.prediction_frame:
+            for track_id, track in self.prediction_frame.tracks.items():
+                if track_id not in self.drawn_tracks:
+                    self.drawn_tracks[track_id] = DrawnTrack(track_id, track, self, self.prediction_frame.H, PROJECTION_MAP)
+                else:
+                    self.drawn_tracks[track_id].set_track(track, self.prediction_frame.H)
+
+        # clean up
+        for track_id in list(self.drawn_tracks.keys()):
+            # TODO make delay configurable
+            if self.drawn_tracks[track_id].update_at < time.time() - 5:
+                # TODO fade out
+                del self.drawn_tracks[track_id]
+
+
+    def on_key_press(self, symbol, modifiers):
+        print('A key was pressed, use f to hide')
+        if symbol == ord('f'):
+            self.window.set_fullscreen(not self.window.fullscreen)
+        if symbol == ord('h'):
+            self.hide_stats = not self.hide_stats
+
+    def check_running(self, dt):
+        if not self.is_running.is_set():
+            self.window.close()
+            self.event_loop.exit()
+
+    def on_close(self):
+        self.is_running.clear()
+
+    def on_refresh(self, dt: float):
+        # update shapes
+        # self.bg =
+        for track_id, track in self.drawn_tracks.items():
+            track.update_drawn_positions(dt)
+
+
+        self.refresh_labels(dt)
+
+        # self.shape1 = shapes.Circle(700, 150, 100, color=(50, 0, 30), batch=self.batch_anim)
+        # self.shape3 = shapes.Circle(800, 150, 100, color=(100, 225, 30), batch=self.batch_anim)
+        pass
+
+    def on_draw(self):
+        self.window.clear()
+
+        self.batch_bg.draw()
+
+        for track in self.drawn_tracks.values():
+            for shape in track.shapes:
+                shape.draw() # for some reason the batches don't work
+        for track in self.drawn_tracks.values():
+            for shapes in track.pred_shapes:
+                for shape in shapes:
+                    shape.draw()
+        # self.batch_anim.draw()
+        self.batch_overlay.draw()
+
+
+        # pyglet.graphics.draw(3, pyglet.gl.GL_LINE, ("v2i", (100,200, 600,800)), ('c3B', (255,255,255, 255,255,255)))
+
+
+
+        if not self.hide_stats:
+            self.fps_display.draw()
+
+        # if streaming, capture buffer and send
+        try:
+            if self.streaming_process or self.out_writer:
+                buf = pyglet.image.get_buffer_manager().get_color_buffer()
+                img_data = buf.get_image_data()
+                data = img_data.get_data() # alternative: .get_data("RGBA", image_data.pitch)
+                img = np.asanyarray(data).reshape((img_data.height, img_data.width, 4))
+                img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
+                img = np.flip(img, 0)
+                # img = cv2.flip(img, cv2.0)
+
+                # cv2.imshow('frame', img)
+                # cv2.waitKey(1)
+                if self.streaming_process:
+                    self.streaming_process.stdin.write(img.tobytes())
+                if self.out_writer:
+                    self.out_writer.write(img)
+        except Exception as e:
+            logger.exception(e)
+
+
+
+    def run(self):
+        frame = None
+        prediction_frame = None
+        tracker_frame = None
+
+        i=0
+        first_time = None
+
+        self.event_loop = pyglet.app.EventLoop()
+        pyglet.clock.schedule_interval(self.check_running, 0.1)
+        pyglet.clock.schedule(self.check_frames)
+        self.event_loop.run()
+
+
+
+        # while self.is_running.is_set():
+        #     i+=1
+
+
+        #     # zmq_ev = self.frame_sock.poll(timeout=2000)
+        #     # if not zmq_ev:
+        #     #     # when no data comes in, loop so that is_running is checked
+        #     #     continue
+
+        #     try:
+        #         frame: Frame = self.frame_sock.recv_pyobj(zmq.NOBLOCK)
+        #     except zmq.ZMQError as e:
+        #         # idx = frame.index if frame else "NONE"
+        #         # logger.debug(f"reuse video frame {idx}")
+        #         pass
+        #     # else:
+        #     #     logger.debug(f'new video frame {frame.index}')
+
+
+        #     if frame is None:
+        #         # might need to wait a few iterations before first frame comes available
+        #         time.sleep(.1)
+        #         continue
+
+        #     try:
+        #         prediction_frame: Frame = self.prediction_sock.recv_pyobj(zmq.NOBLOCK)
+        #     except zmq.ZMQError as e:
+        #         logger.debug(f'reuse prediction')
+
+        #     if first_time is None:
+        #         first_time = frame.time
+
+        #     img = decorate_frame(frame, prediction_frame, first_time, self.config)
+
+        #     img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
+
+        #     logger.debug(f"write frame {frame.time - first_time:.3f}s")
+        #     if self.out_writer:
+        #         self.out_writer.write(img)
+        #     if self.streaming_process:
+        #         self.streaming_process.stdin.write(img.tobytes())
+        #     if self.config.render_window:
+        #         cv2.imshow('frame',img)
+        #         cv2.waitKey(1)
+        logger.info('Stopping')
+
+        # if i>2:
+        if self.streaming_process:
+            self.streaming_process.stdin.close()
+        if self.out_writer:
+            self.out_writer.release()
+        if self.streaming_process:
+            # oddly wrapped, because both close and release() take time.
+            self.streaming_process.wait()
+
+# colorset = itertools.product([0,255], repeat=3) # but remove white
+colorset = [(0, 0, 0),
+            (0, 0, 255),
+            (0, 255, 0),
+            (0, 255, 255),
+            (255, 0, 0),
+            (255, 0, 255),
+            (255, 255, 0)
+            ]
+
+
+
+def run_animation_renderer(config: Namespace, is_running: BaseEvent):
+    renderer = AnimationRenderer(config, is_running)
+    renderer.run()
\ No newline at end of file
diff --git a/trap/config.py b/trap/config.py
index 605a4bd..42efe3f 100644
--- a/trap/config.py
+++ b/trap/config.py
@@ -1,6 +1,8 @@
 import argparse
 from pathlib import Path
 import types
+import numpy as np
+import json
 
 from trap.tracker import DETECTORS
 
@@ -49,6 +51,21 @@ frame_emitter_parser = parser.add_argument_group('Frame emitter')
 tracker_parser = parser.add_argument_group('Tracker')
 render_parser = parser.add_argument_group('Renderer')
 
+class HomographyAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=None, **kwargs):
+        if nargs is not None:
+            raise ValueError("nargs not allowed")
+        super().__init__(option_strings, dest, **kwargs)
+    def __call__(self, parser, namespace, values: Path, option_string=None):
+        if values.suffix == '.json':
+            with values.open('r') as fp:
+                H = np.array(json.load(fp))
+        else:
+            H = np.loadtxt(values, delimiter=',')
+        print('%r %r %r' % (namespace, values, option_string))
+        setattr(namespace, self.dest, values)
+        setattr(namespace, 'H', H)
+
 inference_parser.add_argument("--model_dir",
                     help="directory with the model to use for inference",
                     type=str, # TODO: make into Path
@@ -234,7 +251,8 @@ frame_emitter_parser.add_argument("--video-loop",
 tracker_parser.add_argument("--homography",
                     help="File with homography params",
                     type=Path,
-                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
+                    default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt',
+                    action=HomographyAction)
 tracker_parser.add_argument("--save-for-training",
                     help="Specify the path in which to save",
                     type=Path,
@@ -246,6 +264,15 @@ tracker_parser.add_argument("--detector",
 tracker_parser.add_argument("--smooth-tracks",
                     help="Smooth the tracker tracks before sending them to the predictor",
                     action='store_true')
+tracker_parser.add_argument("--frame-width",
+                    help="width of the frames",
+                    type=int,
+                    default=1280)
+tracker_parser.add_argument("--frame-height",
+                    help="height of the frames",
+                    type=int,
+                    default=720)
+
 
 
 # Renderer
diff --git a/trap/frame_emitter.py b/trap/frame_emitter.py
index f07a147..67f5da5 100644
--- a/trap/frame_emitter.py
+++ b/trap/frame_emitter.py
@@ -85,7 +85,7 @@ class Track:
 
     def get_projected_history(self, H) -> np.array:
         foot_coordinates = [d.get_foot_coords() for d in self.history]
-
+        # TODO)) Undistort points before perspective transform
         if len(foot_coordinates):
             coords = cv2.perspectiveTransform(np.array([foot_coordinates]),H)
             return coords[0]
@@ -151,8 +151,8 @@ class FrameEmitter:
             # numeric input is a CV camera
             video = cv2.VideoCapture(int(str(video_path)))
             # TODO: make config variables
-            video.set(cv2.CAP_PROP_FRAME_WIDTH, int(1280))
-            video.set(cv2.CAP_PROP_FRAME_HEIGHT, int(720))
+            video.set(cv2.CAP_PROP_FRAME_WIDTH, int(self.config.frame_width))
+            video.set(cv2.CAP_PROP_FRAME_HEIGHT, int(self.config.frame_height))
             print("exposure!", video.get(cv2.CAP_PROP_AUTO_EXPOSURE))
             video.set(cv2.CAP_PROP_FPS, 5)
         else:
diff --git a/trap/plumber.py b/trap/plumber.py
index a70d0f5..d148813 100644
--- a/trap/plumber.py
+++ b/trap/plumber.py
@@ -9,7 +9,8 @@ import time
 from trap.config import parser
 from trap.frame_emitter import run_frame_emitter
 from trap.prediction_server import run_prediction_server
-from trap.renderer import run_renderer
+from trap.preview_renderer import run_preview_renderer
+from trap.animation_renderer import run_animation_renderer
 from trap.socket_forwarder import run_ws_forwarder
 from trap.tracker import run_tracker
 
@@ -75,7 +76,10 @@ def start():
     if args.render_file or args.render_url or args.render_window:
         procs.append(
-            ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
+            ExceptionHandlingProcess(target=run_preview_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer')
+        )
+        procs.append(
+            ExceptionHandlingProcess(target=run_animation_renderer, kwargs={'config': args, 'is_running': isRunning}, name='map_renderer')
         )
 
     if not args.bypass_prediction:
diff --git a/trap/renderer.py b/trap/preview_renderer.py
similarity index 96%
rename from trap/renderer.py
rename to trap/preview_renderer.py
index 3e3593e..6bc705d 100644
--- a/trap/renderer.py
+++ b/trap/preview_renderer.py
@@ -10,7 +10,7 @@ from multiprocessing import Event
 from multiprocessing.synchronize import Event as BaseEvent
 import cv2
 import numpy as np
-
+import json
 import pyglet
 import pyglet.event
 import zmq
@@ -26,7 +26,7 @@ from trap.frame_emitter import DetectionState, Frame, Track
 
 
-logger = logging.getLogger("trap.renderer")
+logger = logging.getLogger("trap.preview")
 
 class FrameAnimation:
     def __init__(self, frame: Frame):
@@ -55,10 +55,15 @@ def relativePointToPolar(origin, point) -> tuple[float, float]:
 def relativePolarToPoint(origin, r, angle) -> tuple[float, float]:
     return r * np.cos(angle) + origin[0], r * np.sin(angle) + origin[1]
 
+PROJECTION_IMG = 0
+PROJECTION_UNDISTORT = 1
+PROJECTION_MAP = 2
+PROJECTION_PROJECTOR = 4
 
 class DrawnTrack:
-    def __init__(self, track_id, track: Track, renderer: Renderer, H):
+    def __init__(self, track_id, track: Track, renderer: PreviewRenderer, H, draw_projection = PROJECTION_IMG):
         # self.created_at = time.time()
+        self.draw_projection = draw_projection
         self.update_at = self.created_at = time.time()
         self.track_id = track_id
         self.renderer = renderer
@@ -73,14 +78,17 @@ class DrawnTrack:
         self.track = track
         self.H = H
-        self.coords = [d.get_foot_coords() for d in track.history]
+        self.coords = [d.get_foot_coords() for d in track.history] if self.draw_projection == PROJECTION_IMG else track.get_projected_history(self.H)
 
         # perhaps only do in constructor:
         self.inv_H = np.linalg.pinv(self.H)
 
         pred_coords = []
-        for pred_i, pred in enumerate(track.predictions):
-            pred_coords.append(cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0].tolist())
+        if self.draw_projection == PROJECTION_IMG:
+            for pred_i, pred in enumerate(track.predictions):
+                pred_coords.append(cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0].tolist())
+        elif self.draw_projection == PROJECTION_MAP:
+            pred_coords = [pred for pred in track.predictions]
 
         self.pred_coords = pred_coords
         # color = (128,0,128) if pred_i else (128,
@@ -232,7 +240,7 @@
 
 
-class Renderer:
+class PreviewRenderer:
     def __init__(self, config: Namespace, is_running: BaseEvent):
         self.config = config
         self.is_running = is_running
@@ -253,14 +261,23 @@ class Renderer:
         self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
         self.frame_sock.connect(config.zmq_frame_addr)
 
-        self.H = np.loadtxt(self.config.homography, delimiter=',')
+
+        # TODO)) Move loading H to config.py
+        # if self.config.homography.suffix == '.json':
+        #     with self.config.homography.open('r') as fp:
+        #         self.H = np.array(json.load(fp))
+        # else:
+        #     self.H = np.loadtxt(self.config.homography, delimiter=',')
+        print('h', self.config.H)
+        self.H = self.config.H
+        self.inv_H = np.linalg.pinv(self.H)
 
         # TODO: get FPS from frame_emitter
         # self.out = cv2.VideoWriter(str(filename), fourcc, 23.97, (1280,720))
         self.fps = 60
-        self.frame_size = (1280,720)
+        self.frame_size = (self.config.frame_width,self.config.frame_height)
         self.hide_stats = False
         self.out_writer = self.start_writer() if self.config.render_file else None
         self.streaming_process = self.start_streaming() if self.config.render_url else None
@@ -772,6 +789,6 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, con
     return img
 
-def run_renderer(config: Namespace, is_running: BaseEvent):
-    renderer = Renderer(config, is_running)
+def run_preview_renderer(config: Namespace, is_running: BaseEvent):
+    renderer = PreviewRenderer(config, is_running)
     renderer.run()
\ No newline at end of file
diff --git a/trap/tracker.py b/trap/tracker.py
index 1e2f0ce..ed2abff 100644
--- a/trap/tracker.py
+++ b/trap/tracker.py
@@ -105,7 +105,7 @@ class Tracker:
 
         # homography = list(source.glob('*img2world.txt'))[0]
-        self.H = np.loadtxt(self.config.homography, delimiter=',')
+        self.H = self.config.H
 
         if self.config.smooth_tracks:
             logger.info("Smoother enabled")
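
Note on the new HomographyAction: the custom argparse action loads the homography matrix at parse time, so the downstream processes (Tracker, PreviewRenderer, AnimationRenderer) can read config.H directly instead of each calling np.loadtxt themselves. A minimal usage sketch, not part of the commit (the temporary file and the identity matrix are made up for illustration):

# Sketch: exercising HomographyAction from trap/config.py with a throwaway .json file.
import json
import tempfile
from argparse import ArgumentParser
from pathlib import Path

import numpy as np

from trap.config import HomographyAction

parser = ArgumentParser()
parser.add_argument("--homography", type=Path, action=HomographyAction)

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fp:
    json.dump(np.eye(3).tolist(), fp)  # identity homography, illustration only

args = parser.parse_args(["--homography", fp.name])
print(args.homography)         # Path to the .json file
print(args.H)                  # 3x3 np.ndarray set by the action
print(np.linalg.pinv(args.H))  # the renderers derive inv_H from config.H like this

Keep in mind that argparse only invokes a custom action for values given on the command line; a default value is applied without running the action, so in this sketch H is only set on the namespace when --homography is passed explicitly.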
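
For context, the three SUB sockets in AnimationRenderer all follow the same conflated-subscriber pattern mentioned in the inline comments (CONFLATE set before connect, non-blocking receive). A small sketch under those assumptions; the address below is a placeholder, the real ones come from config.zmq_frame_addr and friends:

# Sketch of the conflated SUB socket pattern used by AnimationRenderer.check_frames().
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt(zmq.CONFLATE, 1)     # keep only the most recent message; must precede connect()
sock.setsockopt(zmq.SUBSCRIBE, b'')  # subscribe to everything
sock.connect("ipc:///tmp/frame-feed")  # placeholder address

try:
    frame = sock.recv_pyobj(zmq.NOBLOCK)  # raises zmq.Again (a ZMQError) when nothing is queued
except zmq.ZMQError:
    frame = None  # no new frame yet; the renderer keeps reusing the previous one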