2020-01-13 19:55:45 +01:00
|
|
|
from utils import prediction_output_to_trajectories
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import matplotlib.patheffects as pe
|
2020-04-06 03:43:49 +02:00
|
|
|
import numpy as np
|
|
|
|
import seaborn as sns
|
2020-01-13 19:55:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
def plot_trajectories(ax,
                      prediction_dict,
                      histories_dict,
                      futures_dict,
                      line_alpha=0.7,
                      line_width=0.2,
                      edge_width=2,
                      circle_edge_width=0.5,
                      node_circle_size=0.3,
                      batch_num=0,
                      kde=False):
    """Draw observed history, sampled predictions and ground-truth future per node.

    For every node key in ``histories_dict``:
      * the history is drawn as a black dashed line,
      * optionally, one KDE per prediction timestep is drawn,
      * every predicted sample is drawn as a thin line coloured by
        ``cmap[node.type.value]``,
      * the ground-truth future is drawn as a white dashed line with a black
        outline,
      * the last history point is marked with a green circle.

    Nodes whose last history row contains NaN are skipped entirely.

    Args:
        ax: Matplotlib axes to draw on.
        prediction_dict: Maps node -> predictions array, indexed as
            ``[batch, sample, timestep, coord]`` (coord 0 = x, 1 = y).
        histories_dict: Maps node -> history array, indexed as ``[timestep, coord]``.
        futures_dict: Maps node -> ground-truth future array, indexed as
            ``[timestep, coord]``.
        line_alpha: Alpha for the prediction sample lines. It is overridden
            with 0.2 only for nodes where the KDE is drawn.
        line_width: Line width for the prediction sample lines.
        edge_width: Width of the black outline around the future line.
        circle_edge_width: Edge line width of the current-position circle.
        node_circle_size: Radius of the current-position circle.
        batch_num: Batch index selected from each predictions array.
        kde: If True, draw per-timestep KDEs for nodes with at least 50 samples.
    """
    cmap = ['k', 'b', 'y', 'g', 'r']

    for node in histories_dict:
        history = histories_dict[node]
        future = futures_dict[node]
        predictions = prediction_dict[node]

        # A NaN in the most recent observation means there is no current
        # position to anchor the plot on, so skip this node.
        if np.isnan(history[-1]).any():
            continue

        ax.plot(history[:, 0], history[:, 1], 'k--')

        # Per-node alpha. Rebinding the ``line_alpha`` parameter here would
        # leak the dimmed value into every node plotted afterwards.
        node_line_alpha = line_alpha

        # Draw the KDE once per node. It depends only on the set of samples,
        # not on any individual sample, so it must not sit in the sample loop.
        if kde and predictions.shape[1] >= 50:
            node_line_alpha = 0.2
            for t in range(predictions.shape[2]):
                # NOTE(review): positional x/y and shade/shade_lowest match the
                # older seaborn API this file was written against; newer seaborn
                # deprecates them. Confirm against the pinned seaborn version.
                sns.kdeplot(predictions[batch_num, :, t, 0], predictions[batch_num, :, t, 1],
                            ax=ax, shade=True, shade_lowest=False,
                            color=np.random.choice(cmap), alpha=0.8)

        for sample_num in range(predictions.shape[1]):
            ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1],
                    color=cmap[node.type.value],
                    linewidth=line_width, alpha=node_line_alpha)

        # Ground-truth future: white dashes with a black stroke so it stays
        # visible over both light and dark backgrounds.
        ax.plot(future[:, 0],
                future[:, 1],
                'w--',
                path_effects=[pe.Stroke(linewidth=edge_width, foreground='k'), pe.Normal()])

        # Current node position (last history point).
        circle = plt.Circle((history[-1, 0],
                             history[-1, 1]),
                            node_circle_size,
                            facecolor='g',
                            edgecolor='k',
                            lw=circle_edge_width,
                            zorder=3)
        ax.add_artist(circle)

    ax.axis('equal')
|
2020-01-13 19:55:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
def visualize_prediction(ax,
                         prediction_output_dict,
                         dt,
                         max_hl,
                         ph,
                         robot_node=None,
                         map=None,
                         **kwargs):
    """Plot the predictions for a single timestep onto ``ax``.

    The raw output is converted with ``prediction_output_to_trajectories`` into
    per-timestep prediction / history / future dictionaries. The single
    timestep they contain is then drawn with ``plot_trajectories``.

    Args:
        ax: Matplotlib axes to draw on.
        prediction_output_dict: Raw prediction output, forwarded to
            ``prediction_output_to_trajectories``.
        dt: Forwarded unchanged to ``prediction_output_to_trajectories``.
        max_hl: Forwarded unchanged to ``prediction_output_to_trajectories``.
        ph: Forwarded unchanged to ``prediction_output_to_trajectories``.
        robot_node: Accepted for interface compatibility; not used here.
        map: Optional map object. It is forwarded to
            ``prediction_output_to_trajectories``. When not None, its
            ``as_image()`` is drawn as a semi-transparent background.
            (The name shadows the builtin but is part of the public signature.)
        **kwargs: Forwarded as keyword arguments to ``plot_trajectories``
            (e.g. ``line_width``, ``kde``).

    Returns:
        None. Returns early without drawing if there is no timestep to plot.

    Raises:
        AssertionError: If the output contains more than one timestep.
    """
    prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories(
        prediction_output_dict, dt, max_hl, ph, map=map)

    # This visualizer only supports a single timestep.
    assert len(prediction_dict) <= 1, "visualize_prediction expects at most one timestep"
    if not prediction_dict:
        return

    ts_key = next(iter(prediction_dict))

    prediction_dict = prediction_dict[ts_key]
    histories_dict = histories_dict[ts_key]
    futures_dict = futures_dict[ts_key]

    if map is not None:
        ax.imshow(map.as_image(), origin='lower', alpha=0.5)

    # Forward options as keyword arguments. The previous ``*kwargs`` unpacked
    # the dict's *keys* (strings) as positional arguments to plot_trajectories.
    plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, **kwargs)
|