diff --git a/trap/config.py b/trap/config.py index 6f763f5..8471c04 100644 --- a/trap/config.py +++ b/trap/config.py @@ -1,6 +1,8 @@ import argparse from pathlib import Path +from pyparsing import Optional + parser = argparse.ArgumentParser() @@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference') connection_parser = parser.add_argument_group('Connection') frame_emitter_parser = parser.add_argument_group('Frame emitter') tracker_parser = parser.add_argument_group('Tracker') +render_parser = parser.add_argument_group('Renderer') inference_parser.add_argument("--model_dir", help="directory with the model to use for inference", @@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict", inference_parser.add_argument("--output_dir", help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)", - type=str, + type=Path, default='./OUT/test_inference') @@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography", help="File with homography params", type=Path, default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt') + +# Renderer + +# render_parser.add_argument("--output-dir", +# help="Target image dir", +# type=Optional[Path], +# default=None) + diff --git a/trap/plumber.py b/trap/plumber.py index a33f487..2f9a6b7 100644 --- a/trap/plumber.py +++ b/trap/plumber.py @@ -5,6 +5,7 @@ import sys from trap.config import parser from trap.frame_emitter import run_frame_emitter from trap.prediction_server import run_prediction_server +from trap.renderer import run_renderer from trap.socket_forwarder import run_ws_forwarder from trap.tracker import run_tracker @@ -52,6 +53,7 @@ def start(): ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'), ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'), ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'), + ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'), ] if not args.bypass_prediction: procs.append( diff --git a/trap/prediction_server.py b/trap/prediction_server.py index 772c1ad..ba68b68 100644 --- a/trap/prediction_server.py +++ b/trap/prediction_server.py @@ -4,8 +4,11 @@ import logging from multiprocessing import Event, Queue import os import pickle +import sys import time import json +import traceback +import warnings import pandas as pd import torch import dill @@ -301,7 +304,7 @@ class PredictionServer: start = time.time() dists, preds = trajectron.incremental_forward(input_dict, maps, - prediction_horizon=10, # TODO: make variable + prediction_horizon=20, # TODO: make variable num_samples=2, # TODO: make variable robot_present_and_future=robot_present_and_future, full_dist=True) diff --git a/trap/renderer.py b/trap/renderer.py new file mode 100644 index 0000000..c185fcc --- /dev/null +++ b/trap/renderer.py @@ -0,0 +1,91 @@ + +from argparse import Namespace +import logging +from multiprocessing import Event +import cv2 +import numpy as np + +import zmq + +from trap.frame_emitter import Frame + + +logger = logging.getLogger("trap.renderer") + +class Renderer: + def __init__(self, config: Namespace, is_running: Event): + self.config = config + self.is_running = is_running + + context = zmq.Context() + self.prediction_sock = context.socket(zmq.SUB) + self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! + self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'') + self.prediction_sock.connect(config.zmq_prediction_addr) + + self.frame_sock = context.socket(zmq.SUB) + self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!! + self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'') + self.frame_sock.connect(config.zmq_frame_addr) + + + H = np.loadtxt(self.config.homography, delimiter=',') + + self.inv_H = np.linalg.pinv(H) + + if not self.config.output_dir.exists(): + raise FileNotFoundError("Path does not exist") + + def run(self): + predictions = {} + i=0 + first_time = None + while self.is_running.is_set(): + i+=1 + frame: Frame = self.frame_sock.recv_pyobj() + try: + predictions = self.prediction_sock.recv_json(zmq.NOBLOCK) + except zmq.ZMQError as e: + logger.debug(f'reuse prediction') + + img = frame.img + for track_id, prediction in predictions.items(): + if not 'history' in prediction or not len(prediction['history']): + continue + coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0] + # logger.warning(f"{coords=}") + center = [int(p) for p in coords[-1]] + cv2.circle(img, center, 5, (0,255,0)) + + for ci in range(1, len(coords)): + start = [int(p) for p in coords[ci-1]] + end = [int(p) for p in coords[ci]] + cv2.line(img, start, end, (255,255,255), 2) + + if not 'predictions' in prediction or not len(prediction['predictions']): + continue + + for pred in prediction['predictions']: + pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0] + for ci in range(1, len(pred_coords)): + start = [int(p) for p in pred_coords[ci-1]] + end = [int(p) for p in pred_coords[ci]] + cv2.line(img, start, end, (0,0,255), 2) + + + + if first_time is None: + first_time = frame.time + + cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1) + + img_path = (self.config.output_dir / f"{i:05d}.png").resolve() + + cv2.imwrite(str(img_path), img) + logger.info('Stopping') + + + +def run_renderer(config: Namespace, is_running: Event): + renderer = Renderer(config, is_running) + renderer.run() \ No newline at end of file