Predictor options now configurable and rendered

This commit is contained in:
Ruben van de Ven 2024-04-29 18:35:22 +02:00
parent 8d9c7d3486
commit 7710794bad
3 changed files with 47 additions and 16 deletions

View file

@ -156,6 +156,28 @@ inference_parser.add_argument("--smooth-predictions",
help="Smooth the predicted tracks", help="Smooth the predicted tracks",
action='store_true') action='store_true')
inference_parser.add_argument('--prediction-horizon',
help='Trajectron.incremental_forward parameter',
type=int,
default=30)
inference_parser.add_argument('--num-samples',
help='Trajectron.incremental_forward parameter',
type=int,
default=5)
inference_parser.add_argument("--full-dist",
help="Trajectron.incremental_forward parameter",
type=bool,
default=False)
inference_parser.add_argument("--gmm-mode",
help="Trajectron.incremental_forward parameter",
type=bool,
default=True)
inference_parser.add_argument("--z-mode",
help="Trajectron.incremental_forward parameter",
type=bool,
default=False)
# Internal connections. # Internal connections.
connection_parser.add_argument('--zmq-trajectory-addr', connection_parser.add_argument('--zmq-trajectory-addr',

View file

@ -191,7 +191,7 @@ class PredictionServer:
# You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1], # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
# so that you can immediately start incremental inference from the 3rd timestep onwards. # so that you can immediately start incremental inference from the 3rd timestep onwards.
init_timestep = 1 init_timestep = 2
eval_scene = eval_env.scenes[scene_idx] eval_scene = eval_env.scenes[scene_idx]
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep) online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
@ -314,24 +314,25 @@ class PredictionServer:
maps = get_maps_for_input(input_dict, eval_scene, hyperparams) maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
# print(maps) # print(maps)
robot_present_and_future = None # robot_present_and_future = None
if eval_scene.robot is not None and hyperparams['incl_robot_node']: # if eval_scene.robot is not None and hyperparams['incl_robot_node']:
robot_present_and_future = eval_scene.robot.get(np.array([timestep, # robot_present_and_future = eval_scene.robot.get(np.array([timestep,
timestep + hyperparams['prediction_horizon']]), # timestep + hyperparams['prediction_horizon']]),
hyperparams['state'][eval_scene.robot.type], # hyperparams['state'][eval_scene.robot.type],
padding=0.0) # padding=0.0)
robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0) # robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0)
# robot_present_and_future += adjustment # # robot_present_and_future += adjustment
start = time.time() start = time.time()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py warnings.simplefilter('ignore') # prevent deluge of UserWarning from torch's rrn.py
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=125, # TODO: make variable prediction_horizon=self.config.prediction_horizon, # TODO: make variable
num_samples=5, # TODO: make variable num_samples=self.config.num_samples, # TODO: make variable
robot_present_and_future=robot_present_and_future, full_dist=self.config.full_dist,
full_dist=True) gmm_mode=self.config.gmm_mode,
z_mode=self.config.z_mode)
end = time.time() end = time.time()
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start, logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
1. / (end - start), len(trajectron.nodes), 1. / (end - start), len(trajectron.nodes),

View file

@ -99,7 +99,7 @@ class Renderer:
if first_time is None: if first_time is None:
first_time = frame.time first_time = frame.time
decorate_frame(frame, prediction_frame, first_time) decorate_frame(frame, prediction_frame, first_time, self.config)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve() img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
@ -132,7 +132,7 @@ colorset = [(0, 0, 0),
] ]
def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.array: def decorate_frame(frame: Frame, prediction_frame: Frame, first_time: float, config: Namespace) -> np.array:
frame.img frame.img
overlay = np.zeros(frame.img.shape, np.uint8) overlay = np.zeros(frame.img.shape, np.uint8)
@ -225,8 +225,16 @@ def decorate_frame(frame: Frame, prediction_frame: Frame, first_time) -> np.arra
cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()])}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1) cv2.putText(img, f"h: {np.average([len(t.history or []) for t in prediction_frame.tracks.values()])}", (580,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()])}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1) cv2.putText(img, f"ph: {np.average([len(t.predictor_history or []) for t in prediction_frame.tracks.values()])}", (660,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()])}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1) cv2.putText(img, f"p: {np.average([len(t.predictions or []) for t in prediction_frame.tracks.values()])}", (740,17), cv2.FONT_HERSHEY_PLAIN, 1, info_color, 1)
options = []
for option in ['prediction_horizon','num_samples','full_dist','gmm_mode','z_mode', 'model_dir']:
options.append(f"{option}: {config.__dict__[option]}")
return img cv2.putText(img, options.pop(-1), (20,img.shape[0]-30), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
cv2.putText(img, " | ".join(options), (20,img.shape[0]-10), cv2.FONT_HERSHEY_PLAIN, 1, base_color, 1)
return img
def run_renderer(config: Namespace, is_running: Event): def run_renderer(config: Namespace, is_running: Event):