Test renderer to frames
This commit is contained in:
parent
3d34263a71
commit
821d06c9cf
4 changed files with 109 additions and 2 deletions
|
@ -1,6 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pyparsing import Optional
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
|
||||||
connection_parser = parser.add_argument_group('Connection')
|
connection_parser = parser.add_argument_group('Connection')
|
||||||
frame_emitter_parser = parser.add_argument_group('Frame emitter')
|
frame_emitter_parser = parser.add_argument_group('Frame emitter')
|
||||||
tracker_parser = parser.add_argument_group('Tracker')
|
tracker_parser = parser.add_argument_group('Tracker')
|
||||||
|
render_parser = parser.add_argument_group('Renderer')
|
||||||
|
|
||||||
inference_parser.add_argument("--model_dir",
|
inference_parser.add_argument("--model_dir",
|
||||||
help="directory with the model to use for inference",
|
help="directory with the model to use for inference",
|
||||||
|
@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
|
||||||
|
|
||||||
inference_parser.add_argument("--output_dir",
|
inference_parser.add_argument("--output_dir",
|
||||||
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
||||||
type=str,
|
type=Path,
|
||||||
default='./OUT/test_inference')
|
default='./OUT/test_inference')
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
|
||||||
help="File with homography params",
|
help="File with homography params",
|
||||||
type=Path,
|
type=Path,
|
||||||
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
|
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
|
||||||
|
|
||||||
|
# Renderer
|
||||||
|
|
||||||
|
# render_parser.add_argument("--output-dir",
|
||||||
|
# help="Target image dir",
|
||||||
|
# type=Optional[Path],
|
||||||
|
# default=None)
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ import sys
|
||||||
from trap.config import parser
|
from trap.config import parser
|
||||||
from trap.frame_emitter import run_frame_emitter
|
from trap.frame_emitter import run_frame_emitter
|
||||||
from trap.prediction_server import run_prediction_server
|
from trap.prediction_server import run_prediction_server
|
||||||
|
from trap.renderer import run_renderer
|
||||||
from trap.socket_forwarder import run_ws_forwarder
|
from trap.socket_forwarder import run_ws_forwarder
|
||||||
from trap.tracker import run_tracker
|
from trap.tracker import run_tracker
|
||||||
|
|
||||||
|
@ -52,6 +53,7 @@ def start():
|
||||||
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
|
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
|
||||||
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
|
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
|
||||||
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
|
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
|
||||||
|
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
|
||||||
]
|
]
|
||||||
if not args.bypass_prediction:
|
if not args.bypass_prediction:
|
||||||
procs.append(
|
procs.append(
|
||||||
|
|
|
@ -4,8 +4,11 @@ import logging
|
||||||
from multiprocessing import Event, Queue
|
from multiprocessing import Event, Queue
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import dill
|
import dill
|
||||||
|
@ -301,7 +304,7 @@ class PredictionServer:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
dists, preds = trajectron.incremental_forward(input_dict,
|
dists, preds = trajectron.incremental_forward(input_dict,
|
||||||
maps,
|
maps,
|
||||||
prediction_horizon=10, # TODO: make variable
|
prediction_horizon=20, # TODO: make variable
|
||||||
num_samples=2, # TODO: make variable
|
num_samples=2, # TODO: make variable
|
||||||
robot_present_and_future=robot_present_and_future,
|
robot_present_and_future=robot_present_and_future,
|
||||||
full_dist=True)
|
full_dist=True)
|
||||||
|
|
91
trap/renderer.py
Normal file
91
trap/renderer.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
import logging
|
||||||
|
from multiprocessing import Event
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from trap.frame_emitter import Frame
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("trap.renderer")
|
||||||
|
|
||||||
|
class Renderer:
|
||||||
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
self.prediction_sock = context.socket(zmq.SUB)
|
||||||
|
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.prediction_sock.connect(config.zmq_prediction_addr)
|
||||||
|
|
||||||
|
self.frame_sock = context.socket(zmq.SUB)
|
||||||
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.frame_sock.connect(config.zmq_frame_addr)
|
||||||
|
|
||||||
|
|
||||||
|
H = np.loadtxt(self.config.homography, delimiter=',')
|
||||||
|
|
||||||
|
self.inv_H = np.linalg.pinv(H)
|
||||||
|
|
||||||
|
if not self.config.output_dir.exists():
|
||||||
|
raise FileNotFoundError("Path does not exist")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
predictions = {}
|
||||||
|
i=0
|
||||||
|
first_time = None
|
||||||
|
while self.is_running.is_set():
|
||||||
|
i+=1
|
||||||
|
frame: Frame = self.frame_sock.recv_pyobj()
|
||||||
|
try:
|
||||||
|
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
logger.debug(f'reuse prediction')
|
||||||
|
|
||||||
|
img = frame.img
|
||||||
|
for track_id, prediction in predictions.items():
|
||||||
|
if not 'history' in prediction or not len(prediction['history']):
|
||||||
|
continue
|
||||||
|
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||||
|
# logger.warning(f"{coords=}")
|
||||||
|
center = [int(p) for p in coords[-1]]
|
||||||
|
cv2.circle(img, center, 5, (0,255,0))
|
||||||
|
|
||||||
|
for ci in range(1, len(coords)):
|
||||||
|
start = [int(p) for p in coords[ci-1]]
|
||||||
|
end = [int(p) for p in coords[ci]]
|
||||||
|
cv2.line(img, start, end, (255,255,255), 2)
|
||||||
|
|
||||||
|
if not 'predictions' in prediction or not len(prediction['predictions']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for pred in prediction['predictions']:
|
||||||
|
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||||
|
for ci in range(1, len(pred_coords)):
|
||||||
|
start = [int(p) for p in pred_coords[ci-1]]
|
||||||
|
end = [int(p) for p in pred_coords[ci]]
|
||||||
|
cv2.line(img, start, end, (0,0,255), 2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if first_time is None:
|
||||||
|
first_time = frame.time
|
||||||
|
|
||||||
|
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||||
|
|
||||||
|
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
|
||||||
|
|
||||||
|
cv2.imwrite(str(img_path), img)
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_renderer(config: Namespace, is_running: Event):
|
||||||
|
renderer = Renderer(config, is_running)
|
||||||
|
renderer.run()
|
Loading…
Reference in a new issue