Test renderer to frames

This commit is contained in:
Ruben van de Ven 2023-10-16 16:49:20 +02:00
parent 3d34263a71
commit 821d06c9cf
4 changed files with 109 additions and 2 deletions

View file

@ -1,6 +1,8 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from pyparsing import Optional
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
connection_parser = parser.add_argument_group('Connection') connection_parser = parser.add_argument_group('Connection')
frame_emitter_parser = parser.add_argument_group('Frame emitter') frame_emitter_parser = parser.add_argument_group('Frame emitter')
tracker_parser = parser.add_argument_group('Tracker') tracker_parser = parser.add_argument_group('Tracker')
render_parser = parser.add_argument_group('Renderer')
inference_parser.add_argument("--model_dir", inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference", help="directory with the model to use for inference",
@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
inference_parser.add_argument("--output_dir", inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)", help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=str, type=Path,
default='./OUT/test_inference') default='./OUT/test_inference')
@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
help="File with homography params", help="File with homography params",
type=Path, type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt') default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
# Renderer
# render_parser.add_argument("--output-dir",
# help="Target image dir",
# type=Optional[Path],
# default=None)

View file

@ -5,6 +5,7 @@ import sys
from trap.config import parser from trap.config import parser
from trap.frame_emitter import run_frame_emitter from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import run_prediction_server from trap.prediction_server import run_prediction_server
from trap.renderer import run_renderer
from trap.socket_forwarder import run_ws_forwarder from trap.socket_forwarder import run_ws_forwarder
from trap.tracker import run_tracker from trap.tracker import run_tracker
@ -52,6 +53,7 @@ def start():
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'), ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'), ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'), ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
] ]
if not args.bypass_prediction: if not args.bypass_prediction:
procs.append( procs.append(

View file

@ -4,8 +4,11 @@ import logging
from multiprocessing import Event, Queue from multiprocessing import Event, Queue
import os import os
import pickle import pickle
import sys
import time import time
import json import json
import traceback
import warnings
import pandas as pd import pandas as pd
import torch import torch
import dill import dill
@ -301,7 +304,7 @@ class PredictionServer:
start = time.time() start = time.time()
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=10, # TODO: make variable prediction_horizon=20, # TODO: make variable
num_samples=2, # TODO: make variable num_samples=2, # TODO: make variable
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)

91
trap/renderer.py Normal file
View file

@ -0,0 +1,91 @@
from argparse import Namespace
import logging
from multiprocessing import Event
import cv2
import numpy as np
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer")
class Renderer:
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
context = zmq.Context()
self.prediction_sock = context.socket(zmq.SUB)
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_sock.connect(config.zmq_prediction_addr)
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(config.zmq_frame_addr)
H = np.loadtxt(self.config.homography, delimiter=',')
self.inv_H = np.linalg.pinv(H)
if not self.config.output_dir.exists():
raise FileNotFoundError("Path does not exist")
def run(self):
predictions = {}
i=0
first_time = None
while self.is_running.is_set():
i+=1
frame: Frame = self.frame_sock.recv_pyobj()
try:
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
except zmq.ZMQError as e:
logger.debug(f'reuse prediction')
img = frame.img
for track_id, prediction in predictions.items():
if not 'history' in prediction or not len(prediction['history']):
continue
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
# logger.warning(f"{coords=}")
center = [int(p) for p in coords[-1]]
cv2.circle(img, center, 5, (0,255,0))
for ci in range(1, len(coords)):
start = [int(p) for p in coords[ci-1]]
end = [int(p) for p in coords[ci]]
cv2.line(img, start, end, (255,255,255), 2)
if not 'predictions' in prediction or not len(prediction['predictions']):
continue
for pred in prediction['predictions']:
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
for ci in range(1, len(pred_coords)):
start = [int(p) for p in pred_coords[ci-1]]
end = [int(p) for p in pred_coords[ci]]
cv2.line(img, start, end, (0,0,255), 2)
if first_time is None:
first_time = frame.time
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
cv2.imwrite(str(img_path), img)
logger.info('Stopping')
def run_renderer(config: Namespace, is_running: Event):
renderer = Renderer(config, is_running)
renderer.run()