Trajectron-plus-plus/trajectron/utils/trajectory_utils.py

49 lines
1.8 KiB
Python
Raw Normal View History

2020-01-13 18:55:45 +00:00
import numpy as np
def prediction_output_to_trajectories(prediction_output_dict,
                                      dt,
                                      max_h,
                                      ph,
                                      map=None,
                                      prune_ph_to_future=False):
    """Collect predicted, past and ground-truth future trajectories per node.

    For every prediction timestep ``t`` and every node predicted at ``t``,
    look up the node's observed positions up to and including ``t``
    (history) and after ``t`` (future). Drop rows containing NaN, and store
    both next to the node's predicted trajectory.

    Args:
        prediction_output_dict: Mapping ``{t: {node: predictions}}``. Each
            ``node`` must be hashable and provide
            ``get(timestep_range, state)`` returning an array with at least
            two dimensions, presumably ``(num_steps, 2)`` positions (TODO
            confirm). ``predictions`` is an array whose ``axis=2`` indexes
            prediction steps, presumably ``(samples, batch, ph, 2)`` (confirm
            against the model output).
        dt: Accepted for interface compatibility; not used by this function.
        max_h: History length. The queried range is ``[t - max_h, t]``, so
            the current position at ``t`` is included.
        ph: Prediction horizon. The ground-truth future is queried over
            ``[t + 1, t + ph]``.
        map: Optional object exposing ``to_map_points(points)``. When given,
            the history, prediction and future arrays are all passed through
            it before being stored. When ``None`` they are stored unchanged.
        prune_ph_to_future: If true, truncate each prediction along
            ``axis=2`` to the number of valid (non-NaN) future rows. Nodes
            with no valid future rows are skipped entirely.

    Returns:
        Tuple ``(output_dict, histories_dict, futures_dict)``, each shaped
        ``{t: {node: array}}``. Every timestep of ``prediction_output_dict``
        gets a (possibly empty) inner dict in all three results.
    """
    # Every node is queried for the same position channels, so build the
    # state spec once instead of once per node.
    position_state = {'position': ['x', 'y']}

    def to_points(points):
        # Identity without a map; otherwise convert into map coordinates.
        return points if map is None else map.to_map_points(points)

    output_dict = dict()
    histories_dict = dict()
    futures_dict = dict()

    for t, node_predictions in prediction_output_dict.items():
        histories_dict[t] = dict()
        output_dict[t] = dict()
        futures_dict[t] = dict()
        for node, predictions_output in node_predictions.items():
            # History includes the current position at t.
            history = node.get(np.array([t - max_h, t]), position_state)
            # A row sums to NaN if any of its coordinates is NaN, so this
            # keeps only fully observed rows.
            history = history[~np.isnan(history.sum(axis=1))]

            future = node.get(np.array([t + 1, t + ph]), position_state)
            future = future[~np.isnan(future.sum(axis=1))]

            if prune_ph_to_future:
                # NOTE(review): truncation uses the *count* of valid future
                # rows. That lines predictions up with the ground truth only
                # if all missing rows are at the end of the horizon; confirm.
                predictions_output = predictions_output[:, :, :future.shape[0]]
                if predictions_output.shape[2] == 0:
                    # No ground truth left to compare against; drop the node.
                    continue

            histories_dict[t][node] = to_points(history)
            output_dict[t][node] = to_points(predictions_output)
            futures_dict[t][node] = to_points(future)

    return output_dict, histories_dict, futures_dict