Trajectron-plus-plus/trajectron/visualization/visualization.py
2020-04-05 21:43:49 -04:00

89 lines
No EOL
3.3 KiB
Python

from utils import prediction_output_to_trajectories
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
import numpy as np
import seaborn as sns
def plot_trajectories(ax,
                      prediction_dict,
                      histories_dict,
                      futures_dict,
                      line_alpha=0.7,
                      line_width=0.2,
                      edge_width=2,
                      circle_edge_width=0.5,
                      node_circle_size=0.3,
                      batch_num=0,
                      kde=False):
    """Draw each node's history, predicted samples and ground-truth future on ``ax``.

    For every node key in ``histories_dict`` this draws:
      * the history as a black dashed line,
      * optionally (``kde=True`` and at least 50 samples) a filled KDE of the
        predicted positions at every prediction step,
      * every predicted sample trajectory of batch entry ``batch_num`` as a thin
        line coloured by ``node.type.value``,
      * the future as a white dashed line with a black outline,
      * a green circle at the node's last history position.

    Nodes whose last history position contains NaN are skipped. Every key of
    ``histories_dict`` must also be a key of ``futures_dict`` and
    ``prediction_dict``.

    Args:
        ax: Matplotlib axes to draw on.
        prediction_dict: node -> predictions array indexed as
            ``[batch, sample, timestep, xy]``.
        histories_dict: node -> history array indexed as ``[timestep, xy]``.
        futures_dict: node -> future array indexed as ``[timestep, xy]``.
        line_alpha: Opacity of the sample lines (lowered to 0.2 for a node whose
            KDE is drawn).
        line_width: Width of the sample lines.
        edge_width: Width of the black outline drawn around the future line.
        circle_edge_width: Edge line width of the current-position circle.
        node_circle_size: Radius of the current-position circle.
        batch_num: Index along the predictions' first axis to plot.
        kde: If True, overlay a KDE of the predicted positions for each
            timestep when a node has at least 50 samples.
    """
    cmap = ['k', 'b', 'y', 'g', 'r']

    for node in histories_dict:
        history = histories_dict[node]
        future = futures_dict[node]
        predictions = prediction_dict[node]

        # Without a valid last position there is nothing to anchor the plot on.
        if np.isnan(history[-1]).any():
            continue

        ax.plot(history[:, 0], history[:, 1], 'k--')

        num_samples = predictions.shape[1]
        # Per-node alpha, so lowering it for a KDE node does not carry over to
        # the nodes plotted after it.
        sample_alpha = line_alpha

        # The KDE depends only on the node (all samples at each timestep), so it
        # is drawn once per node rather than once per sample.
        if kde and num_samples >= 50:
            sample_alpha = 0.2
            for t in range(predictions.shape[2]):
                # NOTE(review): positional x/y plus ``shade``/``shade_lowest`` are
                # deprecated or removed in newer seaborn releases — confirm the
                # seaborn version this project pins.
                sns.kdeplot(predictions[batch_num, :, t, 0], predictions[batch_num, :, t, 1],
                            ax=ax, shade=True, shade_lowest=False,
                            color=np.random.choice(cmap), alpha=0.8)

        # Wrap the type index so type values beyond the palette size do not
        # raise IndexError (values 0-4 map exactly as before).
        sample_color = cmap[node.type.value % len(cmap)]
        for sample_num in range(num_samples):
            ax.plot(predictions[batch_num, sample_num, :, 0],
                    predictions[batch_num, sample_num, :, 1],
                    color=sample_color,
                    linewidth=line_width, alpha=sample_alpha)

        # Future trajectory: white dashes with a black stroke so it stays
        # visible on top of the sample lines.
        ax.plot(future[:, 0],
                future[:, 1],
                'w--',
                path_effects=[pe.Stroke(linewidth=edge_width, foreground='k'), pe.Normal()])

        # Current Node Position
        circle = plt.Circle((history[-1, 0],
                             history[-1, 1]),
                            node_circle_size,
                            facecolor='g',
                            edgecolor='k',
                            lw=circle_edge_width,
                            zorder=3)
        ax.add_artist(circle)

    ax.axis('equal')
def visualize_prediction(ax,
                         prediction_output_dict,
                         dt,
                         max_hl,
                         ph,
                         robot_node=None,
                         map=None,
                         **kwargs):
    """Plot the predictions contained in ``prediction_output_dict`` onto ``ax``.

    The raw output is converted with ``prediction_output_to_trajectories``. The
    resulting dicts are expected to hold at most one outer key (named
    ``ts_key`` here, presumably a timestep — confirm against utils). If ``map``
    is given, its image is drawn as a semi-transparent background. Every
    node's history, samples and future are then drawn via ``plot_trajectories``.
    Nothing is drawn when the conversion yields no entries.

    Args:
        ax: Matplotlib axes to draw on.
        prediction_output_dict: Raw prediction output, passed to
            ``prediction_output_to_trajectories``.
        dt: Forwarded to ``prediction_output_to_trajectories``.
        max_hl: Forwarded to ``prediction_output_to_trajectories``.
        ph: Forwarded to ``prediction_output_to_trajectories``.
        robot_node: Accepted for interface compatibility; not used here.
        map: Optional map object (name shadows the builtin but is part of the
            public signature). It is forwarded to the conversion, and if not
            None, ``map.as_image()`` is drawn with ``origin='lower'``.
        **kwargs: Styling options forwarded as keyword arguments to
            ``plot_trajectories`` (e.g. ``line_width=...``, ``kde=True``).

    Raises:
        AssertionError: if the converted output has more than one outer key.
    """
    prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories(prediction_output_dict,
                                                                                      dt,
                                                                                      max_hl,
                                                                                      ph,
                                                                                      map=map)

    assert len(prediction_dict) <= 1, \
        "visualize_prediction can only plot a single timestep at a time"
    if not prediction_dict:
        return

    ts_key = next(iter(prediction_dict))
    prediction_dict = prediction_dict[ts_key]
    histories_dict = histories_dict[ts_key]
    futures_dict = futures_dict[ts_key]

    if map is not None:
        ax.imshow(map.as_image(), origin='lower', alpha=0.5)

    # Styling options must be forwarded as keyword arguments (``**``). A single
    # ``*`` would pass the option *names* (strings) positionally into
    # line_alpha, line_width, ... instead of their values.
    plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, **kwargs)