2023-10-09 20:27:29 +02:00
|
|
|
from trajectron.utils import prediction_output_to_trajectories
|
2020-12-10 04:42:06 +01:00
|
|
|
from scipy import linalg
|
2020-01-13 19:55:45 +01:00
|
|
|
import matplotlib.pyplot as plt
|
2020-12-10 04:42:06 +01:00
|
|
|
import matplotlib.patches as patches
|
2020-01-13 19:55:45 +01:00
|
|
|
import matplotlib.patheffects as pe
|
2020-04-06 03:43:49 +02:00
|
|
|
import numpy as np
|
|
|
|
import seaborn as sns
|
2020-01-13 19:55:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
def plot_trajectories(ax,
                      prediction_dict,
                      histories_dict,
                      futures_dict,
                      line_alpha=0.7,
                      line_width=0.2,
                      edge_width=2,
                      circle_edge_width=0.5,
                      node_circle_size=0.3,
                      batch_num=0,
                      kde=False,
                      node_indexes=None):
    """Plot history, sampled predictions and ground-truth future for every node.

    For each node key of ``histories_dict`` this draws:
      * the history as a black dashed line,
      * every predicted sample ``predictions[batch_num, s]`` as a thin line,
      * the ground-truth future as a white dashed line with a black outline,
      * a green circle at the last history position.
    Nodes whose last history entry contains NaN are skipped entirely.
    Finally the axes are set to an equal aspect ratio.

    Args:
        ax: matplotlib Axes to draw on.
        prediction_dict: node -> prediction array indexed as
            ``[batch, sample, timestep, coord]`` (coord 0 = x, coord 1 = y).
        histories_dict: node -> history array indexed ``[timestep, coord]``.
        futures_dict: node -> ground-truth future array indexed ``[timestep, coord]``.
        line_alpha: opacity of predicted-sample lines. For a node whose KDE
            is drawn, 0.2 is used instead, for that node only.
        line_width: width of predicted-sample lines.
        edge_width: outline width of the ground-truth future line.
        circle_edge_width: edge width of the current-position marker.
        node_circle_size: radius of the current-position marker.
        batch_num: batch index to read from each prediction array.
        kde: if True and a node has at least 50 samples, draw one seaborn KDE
            per prediction timestep over all of that node's samples.
        node_indexes: optional node -> int mapping used to pick line colours.
            When falsy (``None`` or empty), colours come from ``node.type.value``.
    """
    cmap = ['k', 'b', 'y', 'g', 'r']

    for node in histories_dict:
        history = histories_dict[node]
        future = futures_dict[node]
        predictions = prediction_dict[node]

        if np.isnan(history[-1]).any():
            continue

        ax.plot(history[:, 0], history[:, 1], 'k--')

        # The colour depends only on the node, so compute it once per node
        # rather than once per sample.
        if not node_indexes:
            color = cmap[node.type.value]
        else:
            color = cmap[(node_indexes[node] + 1) % len(cmap)]

        # Per-node alpha: lowering it for a KDE node must not leak into the
        # nodes plotted after it (the original overwrote the parameter).
        sample_alpha = line_alpha
        num_samples = predictions.shape[1]

        if kde and num_samples >= 50:
            sample_alpha = 0.2
            # Each KDE covers all samples of one timestep, so it is drawn once
            # per timestep -- previously it was redrawn for every sample.
            for t in range(predictions.shape[2]):
                sns.kdeplot(predictions[batch_num, :, t, 0], predictions[batch_num, :, t, 1],
                            ax=ax, shade=True, shade_lowest=False,
                            color=np.random.choice(cmap), alpha=0.8)

        for sample_num in range(num_samples):
            ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1],
                    color=color,
                    linewidth=line_width, alpha=sample_alpha)

        ax.plot(future[:, 0],
                future[:, 1],
                'w--',
                path_effects=[pe.Stroke(linewidth=edge_width, foreground='k'), pe.Normal()])

        # Current Node Position
        circle = plt.Circle((history[-1, 0],
                             history[-1, 1]),
                            node_circle_size,
                            facecolor='g',
                            edgecolor='k',
                            lw=circle_edge_width,
                            zorder=3)
        ax.add_artist(circle)

    ax.axis('equal')
|
2020-01-13 19:55:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
def visualize_prediction(ax,
                         prediction_output_dict,
                         dt,
                         max_hl,
                         ph,
                         robot_node=None,
                         map=None,
                         **kwargs):
    """Convert a single-timestep prediction output to trajectories and plot them.

    ``prediction_output_dict`` is converted with
    ``prediction_output_to_trajectories(..., dt, max_hl, ph, map=map)``. The
    result must contain at most one timestep key; if it contains none, nothing
    is drawn. If ``map`` is given, its image is drawn underneath at half
    opacity. Each node is coloured by its position in the prediction dict.

    Args:
        ax: matplotlib Axes to draw on.
        prediction_output_dict: raw prediction output passed through to
            ``prediction_output_to_trajectories``.
        dt, max_hl, ph: forwarded unchanged to ``prediction_output_to_trajectories``.
        robot_node: accepted for API compatibility; not used here.
        map: optional map object exposing ``as_image()``; also forwarded to
            ``prediction_output_to_trajectories``.
        **kwargs: extra keyword options forwarded to ``plot_trajectories``
            (e.g. ``line_alpha``, ``kde``). A caller-supplied ``node_indexes``
            overrides the default per-node colouring.

    Raises:
        AssertionError: if the prediction output spans more than one timestep.
    """
    prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories(prediction_output_dict,
                                                                                      dt,
                                                                                      max_hl,
                                                                                      ph,
                                                                                      map=map)

    assert len(prediction_dict) <= 1
    if not prediction_dict:
        return
    ts_key = next(iter(prediction_dict))

    prediction_dict = prediction_dict[ts_key]
    histories_dict = histories_dict[ts_key]
    futures_dict = futures_dict[ts_key]

    # Colour nodes by their enumeration order within this timestep.
    node_indexes = {node: nr for nr, node in enumerate(prediction_dict)}

    if map is not None:
        ax.imshow(map.as_image(), origin='lower', alpha=0.5)

    # Forward options as keyword arguments. The previous ``*kwargs`` unpacked
    # the option *names* as positional arguments, so e.g. the string
    # 'line_alpha' was bound to plot_trajectories' ``line_alpha`` slot.
    plot_kwargs = {'node_indexes': node_indexes, **kwargs}
    plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, **plot_kwargs)
|
2020-12-10 04:42:06 +01:00
|
|
|
|
|
|
|
|
|
|
|
def visualize_distribution(ax,
                           prediction_distribution_dict,
                           map=None,
                           pi_threshold=0.05,
                           **kwargs):
    """Draw one translucent ellipse per sufficiently-weighted mixture component.

    For every node whose distribution has leading ``mus`` dimensions
    ``(1, 1)``, each (timestep, component) Gaussian with mixture weight
    ``pi >= pi_threshold`` is drawn as an ellipse. The ellipse is centred at
    the component mean, has axes ``2*sqrt(2)*sqrt(eigenvalue)`` of its
    covariance, and has opacity ``pi / 10``. Vehicles are drawn in blue and all
    other node types in orange. Nodes with any other leading shape are skipped.

    Args:
        ax: matplotlib Axes to draw on.
        prediction_distribution_dict: node -> distribution object exposing
            ``mus``, ``get_covariance_matrix()`` and ``pis_cat_dist.probs``
            (tensors with a ``.cpu()`` method).
        map: optional map object exposing ``as_image()``, drawn underneath at
            half opacity.
        pi_threshold: components with a mixture weight below this are not drawn.
        **kwargs: accepted for API compatibility; unused.
    """
    if map is not None:
        ax.imshow(map.as_image(), origin='lower', alpha=0.5)

    for node, pred_dist in prediction_distribution_dict.items():
        # Only single-sample/single-batch distributions are drawn. Skip this
        # node rather than abandoning every remaining node (was ``return``).
        if pred_dist.mus.shape[:2] != (1, 1):
            continue

        # NOTE(review): after squeeze() the code indexes [timestep, component].
        # This assumes neither of those dims has size 1 (squeeze would drop
        # it too) -- confirm for ph == 1 or single-component models.
        means = pred_dist.mus.squeeze().cpu().numpy()
        covs = pred_dist.get_covariance_matrix().squeeze().cpu().numpy()
        pis = pred_dist.pis_cat_dist.probs.squeeze().cpu().numpy()

        color = 'blue' if node.type.name == 'VEHICLE' else 'orange'

        for timestep in range(means.shape[0]):
            for z_val in range(means.shape[1]):
                mean = means[timestep, z_val]
                covar = covs[timestep, z_val]
                pi = pis[timestep, z_val]

                if pi < pi_threshold:
                    continue

                v, w = linalg.eigh(covar)
                # Clip round-off negatives so sqrt never yields NaN axes.
                v = 2. * np.sqrt(2.) * np.sqrt(np.clip(v, 0., None))
                # eigh returns eigenvectors as COLUMNS; column 0 belongs to
                # v[0], which is used as the ellipse width below. (The row
                # w[0] used previously can give a mirrored angle.)
                u = w[:, 0]

                # Plot an ellipse to show the Gaussian component.
                # arctan2 handles u[0] == 0 (vertical axis) without dividing
                # by zero.
                angle = np.degrees(np.arctan2(u[1], u[0]))
                ell = patches.Ellipse(mean, v[0], v[1], angle=angle, color=color)
                ell.set_edgecolor(None)
                ell.set_clip_box(ax.bbox)
                ell.set_alpha(pi / 10)
                ax.add_artist(ell)
|