Test renderer to frames

Ruben van de Ven 2023-10-16 16:49:20 +02:00
parent 3d34263a71
commit 821d06c9cf
4 changed files with 109 additions and 2 deletions

View file

@@ -1,6 +1,8 @@
import argparse
from pathlib import Path
+from typing import Optional
parser = argparse.ArgumentParser()
@@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
connection_parser = parser.add_argument_group('Connection')
frame_emitter_parser = parser.add_argument_group('Frame emitter')
tracker_parser = parser.add_argument_group('Tracker')
+render_parser = parser.add_argument_group('Renderer')
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
@@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=str,
type=Path,
default='./OUT/test_inference')
@@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
help="File with homography params",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
+# Renderer
+# render_parser.add_argument("--output-dir",
+#                     help="Target image dir",
+#                     type=Optional[Path],
+#                     default=None)
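A side note on the commented-out option above: argparse applies `type` as a callable to the raw string value, so `Optional[Path]` would raise a TypeError as soon as the flag is actually passed, and without an explicit `dest` the option would share the `output_dir` destination with the inference group's `--output_dir` flag (which is what the new renderer reads). A minimal sketch, not part of this commit, of how an optional directory flag is usually declared:

# Sketch only, not part of this commit: argparse needs a callable for `type`,
# so use Path; "optional" is expressed via default=None rather than Optional[Path].
render_parser.add_argument("--output-dir",
                    help="Target image dir",
                    type=Path,
                    default=None)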

View file

@@ -5,6 +5,7 @@ import sys
from trap.config import parser
from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import run_prediction_server
+from trap.renderer import run_renderer
from trap.socket_forwarder import run_ws_forwarder
from trap.tracker import run_tracker
@@ -52,6 +53,7 @@ def start():
        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
+        ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
    ]

    if not args.bypass_prediction:
        procs.append(

View file

@@ -4,8 +4,11 @@ import logging
from multiprocessing import Event, Queue
import os
import pickle
import sys
import time
import json
import traceback
import warnings
import pandas as pd
import torch
import dill
@@ -301,7 +304,7 @@ class PredictionServer:
                start = time.time()
                dists, preds = trajectron.incremental_forward(input_dict,
                                maps,
-                               prediction_horizon=10, # TODO: make variable
+                               prediction_horizon=20, # TODO: make variable
                                num_samples=2, # TODO: make variable
                                robot_present_and_future=robot_present_and_future,
                                full_dist=True)
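The hard-coded horizon and sample count are still flagged "# TODO: make variable". One way they could be exposed is sketched below; this is not part of the commit, the flag name is invented, and it assumes the prediction server keeps the parsed arguments on self.config the way the new Renderer does:

# Sketch only: expose the horizon through the existing 'Inference' argument
# group in trap/config.py (hypothetical flag name) ...
inference_parser.add_argument("--prediction-horizon",
                    help="Number of future timesteps to predict per track",
                    type=int,
                    default=20)

# ... and read it where incremental_forward() is called:
#     prediction_horizon=self.config.prediction_horizon,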

trap/renderer.py (new file, 91 lines added)
View file

@@ -0,0 +1,91 @@
from argparse import Namespace
import logging
from multiprocessing import Event
import cv2
import numpy as np
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer")
class Renderer:
    def __init__(self, config: Namespace, is_running: Event):
        self.config = config
        self.is_running = is_running

        context = zmq.Context()
        self.prediction_sock = context.socket(zmq.SUB)
        self.prediction_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.prediction_sock.connect(config.zmq_prediction_addr)

        self.frame_sock = context.socket(zmq.SUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
        self.frame_sock.connect(config.zmq_frame_addr)

        H = np.loadtxt(self.config.homography, delimiter=',')
        self.inv_H = np.linalg.pinv(H)

        if not self.config.output_dir.exists():
            raise FileNotFoundError("Path does not exist")

    def run(self):
        predictions = {}
        i = 0
        first_time = None
        while self.is_running.is_set():
            i += 1
            frame: Frame = self.frame_sock.recv_pyobj()
            try:
                predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
            except zmq.ZMQError as e:
                logger.debug(f'reuse prediction')

            img = frame.img
            for track_id, prediction in predictions.items():
                if not 'history' in prediction or not len(prediction['history']):
                    continue

                coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
                # logger.warning(f"{coords=}")

                center = [int(p) for p in coords[-1]]
                cv2.circle(img, center, 5, (0,255,0))

                for ci in range(1, len(coords)):
                    start = [int(p) for p in coords[ci-1]]
                    end = [int(p) for p in coords[ci]]
                    cv2.line(img, start, end, (255,255,255), 2)

                if not 'predictions' in prediction or not len(prediction['predictions']):
                    continue

                for pred in prediction['predictions']:
                    pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
                    for ci in range(1, len(pred_coords)):
                        start = [int(p) for p in pred_coords[ci-1]]
                        end = [int(p) for p in pred_coords[ci]]
                        cv2.line(img, start, end, (0,0,255), 2)

            if first_time is None:
                first_time = frame.time

            cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)

            img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
            cv2.imwrite(str(img_path), img)
        logger.info('Stopping')


def run_renderer(config: Namespace, is_running: Event):
    renderer = Renderer(config, is_running)
    renderer.run()
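For orientation, Renderer.run() expects each message on the prediction socket to be a JSON object keyed by track id, where every track carries a 'history' list and, optionally, a 'predictions' list of world-space points; both are projected into image space with the inverse homography before drawing. A hypothetical example of that shape (the track id and all coordinate values are invented for illustration):

# Hypothetical message shape, inferred from how Renderer.run() indexes the data;
# "track_1" and the coordinates are illustrative only.
example_predictions = {
    "track_1": {
        "history": [[1.0, 2.0], [1.1, 2.2], [1.2, 2.4]],   # past world-space positions
        "predictions": [                                    # sampled future trajectories
            [[1.3, 2.6], [1.4, 2.8], [1.5, 3.0]],
            [[1.3, 2.5], [1.5, 2.9], [1.7, 3.3]],
        ],
    },
}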