140 lines
6.2 KiB
Python
140 lines
6.2 KiB
Python
import numpy as np
|
|
from scipy.interpolate import RectBivariateSpline
|
|
from scipy.ndimage import binary_dilation
|
|
from scipy.stats import gaussian_kde
|
|
from trajectron.utils import prediction_output_to_trajectories
|
|
import trajectron.visualization as visualization
|
|
from matplotlib import pyplot as plt
|
|
|
|
|
|
def compute_ade(predicted_trajs, gt_traj):
    """Return the average displacement error of each predicted trajectory.

    The Euclidean distance between prediction and ground truth is taken over
    the last (coordinate) axis and then averaged over the axis before it.

    :param predicted_trajs: Array of predicted positions; must broadcast
        against ``gt_traj``.
    :param gt_traj: Array of ground-truth positions.
    :return: 1-D array holding one averaged error per remaining leading index.
    """
    displacement = predicted_trajs - gt_traj
    # Distance per point first, then the mean along the preceding axis.
    per_point_error = np.linalg.norm(displacement, axis=-1)
    return np.ravel(per_point_error.mean(axis=-1))
|
|
|
|
|
|
def compute_fde(predicted_trajs, gt_traj):
    """Return the final displacement error of each predicted trajectory.

    Only the last entry along the third axis of ``predicted_trajs`` and the
    last row of ``gt_traj`` are compared.

    :param predicted_trajs: Array indexed as ``[:, :, timestep, coord]``.
    :param gt_traj: Array indexed as ``[timestep, coord]``.
    :return: 1-D array of Euclidean distances between the final predicted
        positions and the final ground-truth position.
    """
    last_predicted = predicted_trajs[:, :, -1]
    last_truth = gt_traj[-1]
    return np.ravel(np.linalg.norm(last_predicted - last_truth, axis=-1))
|
|
|
|
|
|
def compute_kde_nll(predicted_trajs, gt_traj):
    """Return the mean negative log-likelihood of the ground truth under KDEs.

    For every leading index and every timestep a Gaussian kernel density
    estimate is fitted to the predicted samples and evaluated at the
    ground-truth position. Each log-density is clipped from below at -20
    before being averaged.

    :param predicted_trajs: Array indexed as
        ``[batch, sample, timestep, coord]``.
    :param gt_traj: Array indexed as ``[timestep, coord]``.
    :return: Negated average clipped log-likelihood. It is NaN if any KDE fit
        raised ``numpy.linalg.LinAlgError``.
    """
    floor = -20
    horizon = gt_traj.shape[0]
    normaliser = horizon * predicted_trajs.shape[0]

    total_ll = 0.
    for samples in predicted_trajs:
        for step in range(horizon):
            try:
                # gaussian_kde expects data laid out as (dims, n_samples).
                density = gaussian_kde(samples[:, step].T)
                log_p = np.clip(density.logpdf(gt_traj[step].T), a_min=floor, a_max=None)
                total_ll += log_p[0] / normaliser
            except np.linalg.LinAlgError:
                # The loop keeps running, but every later addition stays NaN,
                # so a single failed fit poisons the result.
                total_ll = np.nan

    return -total_ll
|
|
|
|
|
|
def compute_obs_violations(predicted_trajs, map):
    """Count predicted trajectories that touch a (dilated) obstacle.

    The obstacle mask ``map.data`` is grown with four binary-dilation
    iterations and wrapped in a degree-1 spline. Every predicted point is
    converted with ``map.to_map_points`` and looked up in that spline. A
    trajectory counts as a violation if any of its points has a value above
    zero.

    :param predicted_trajs: Array whose last axis holds the two coordinates
        and whose second-to-last axis holds the timesteps of one trajectory.
        Both ``(samples, timesteps, 2)`` and
        ``(batch, samples, timesteps, 2)`` layouts are accepted.
    :param map: Object exposing ``data`` (2-D obstacle array) and
        ``to_map_points(points)``. NOTE(review): the parameter name shadows
        the ``map`` builtin but is kept for caller compatibility.
    :return: Number of violating trajectories, as a float.
    """
    obs_map = map.data

    # The mask is transposed so the spline is addressed as
    # (first coordinate, second coordinate) = (column, row) of obs_map.
    # The explicit float cast avoids relying on implicit bool conversion.
    dilated_obs = binary_dilation(obs_map.T, iterations=4).astype(float)
    interp_obs_map = RectBivariateSpline(range(obs_map.shape[1]),
                                         range(obs_map.shape[0]),
                                         dilated_obs,
                                         kx=1, ky=1)

    # Every axis but the coordinate axis identifies a single point.
    points_shape = predicted_trajs.shape[:-1]
    pred_trajs_map = map.to_map_points(predicted_trajs.reshape((-1, 2)))

    traj_obs_values = interp_obs_map(pred_trajs_map[:, 0], pred_trajs_map[:, 1], grid=False)
    # Restoring the full leading shape (not just its first two axes) keeps
    # this working for 4-D input; for 3-D input the result is unchanged.
    traj_obs_values = traj_obs_values.reshape(points_shape)
    # Reduce over timesteps so each trajectory yields one worst-case value.
    num_viol_trajs = np.sum(traj_obs_values.max(axis=-1) > 0, dtype=float)

    return num_viol_trajs
|
|
|
|
|
|
def compute_batch_statistics(prediction_output_dict,
                             dt,
                             max_hl,
                             ph,
                             node_type_enum,
                             kde=True,
                             obs=False,
                             map=None,
                             prune_ph_to_future=False,
                             best_of=False):
    """Collect per-node evaluation metrics, grouped by node type.

    The raw prediction output is converted with
    ``prediction_output_to_trajectories``. For every timestep key and node,
    ADE, FDE and optionally the KDE NLL and obstacle-violation count are
    computed and appended to the lists of that node's type.

    :param prediction_output_dict: Forwarded to
        ``prediction_output_to_trajectories``.
    :param dt: Forwarded to ``prediction_output_to_trajectories``.
    :param max_hl: Forwarded to ``prediction_output_to_trajectories``.
    :param ph: Forwarded to ``prediction_output_to_trajectories``.
    :param node_type_enum: Iterable of node types; each gets an entry in the
        result even if no node of that type was predicted.
    :param kde: When true compute the KDE NLL, otherwise record ``0``.
    :param obs: When true compute obstacle violations against ``map``,
        otherwise record ``0``.
    :param map: Map object handed to ``compute_obs_violations``.
    :param prune_ph_to_future: Forwarded to
        ``prediction_output_to_trajectories``.
    :param best_of: When true keep only the minimum ADE and the minimum FDE
        of each node instead of all values.
    :return: ``{node_type: {'ade': [...], 'fde': [...], 'kde': [...],
        'obs_viols': [...]}}``.
    """
    prediction_dict, _, futures_dict = prediction_output_to_trajectories(
        prediction_output_dict,
        dt,
        max_hl,
        ph,
        prune_ph_to_future=prune_ph_to_future)

    metric_names = ('ade', 'fde', 'kde', 'obs_viols')
    batch_error_dict = {node_type: {name: list() for name in metric_names}
                        for node_type in node_type_enum}

    for t, node_predictions in prediction_dict.items():
        for node, predicted in node_predictions.items():
            future = futures_dict[t][node]

            ade_errors = compute_ade(predicted, future)
            fde_errors = compute_fde(predicted, future)
            kde_ll = compute_kde_nll(predicted, future) if kde else 0
            obs_viols = compute_obs_violations(predicted, map) if obs else 0

            if best_of:
                # keepdims leaves the minima as arrays so they can still be
                # extended into the lists below.
                ade_errors = np.min(ade_errors, keepdims=True)
                fde_errors = np.min(fde_errors, keepdims=True)
                kde_ll = np.min(kde_ll)

            node_metrics = batch_error_dict[node.type]
            node_metrics['ade'].extend(list(ade_errors))
            node_metrics['fde'].extend(list(fde_errors))
            node_metrics['kde'].append(kde_ll)
            node_metrics['obs_viols'].append(obs_viols)

    return batch_error_dict
|
|
|
|
|
|
def log_batch_errors(batch_errors_list, log_writer, namespace, curr_iter, bar_plot=(), box_plot=()):
    """Write aggregated batch metrics to a summary writer.

    The node types and metric names are taken from the first entry of
    ``batch_errors_list``. For every (node type, metric) pair the values of
    all batches are concatenated. If any exist, a histogram plus the mean and
    median scalars are logged under ``"<node_type.name>/<namespace>/<metric>"``.

    :param batch_errors_list: List of dicts shaped like
        ``{node_type: {metric: [values, ...]}}``. An empty list logs nothing.
    :param log_writer: Object providing ``add_histogram``, ``add_scalar`` and
        ``add_figure``.
    :param namespace: Middle component of every logged tag; also used as the
        ``'dataset'`` label in plots.
    :param curr_iter: Step value passed to every writer call.
    :param bar_plot: Metric names for which a bar plot figure is also logged.
    :param box_plot: Metric names for which a box plot figure is also logged.
    """
    # Nothing to aggregate, and indexing [0] below would raise IndexError.
    if not batch_errors_list:
        return

    for node_type, metrics in batch_errors_list[0].items():
        for metric in metrics:
            metric_batch_error = []
            for batch_errors in batch_errors_list:
                metric_batch_error.extend(batch_errors[node_type][metric])

            if not metric_batch_error:
                continue

            tag = f"{node_type.name}/{namespace}/{metric}"
            log_writer.add_histogram(tag, metric_batch_error, curr_iter)
            log_writer.add_scalar(f"{tag}_mean", np.mean(metric_batch_error), curr_iter)
            log_writer.add_scalar(f"{tag}_median", np.median(metric_batch_error), curr_iter)

            if metric in bar_plot:
                # Column-style dict: one 'dataset' label per metric value.
                bar_data = {'dataset': [namespace] * len(metric_batch_error),
                            metric: metric_batch_error}
                bar_fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualization_utils.plot_barplots(ax, bar_data, 'dataset', metric)
                log_writer.add_figure(f"{tag}_bar_plot", bar_fig, curr_iter)

            if metric in box_plot:
                box_data = {'dataset': [namespace] * len(metric_batch_error),
                            metric: metric_batch_error}
                box_fig, ax = plt.subplots(figsize=(5, 5))
                visualization.visualization_utils.plot_boxplots(ax, box_data, 'dataset', metric)
                log_writer.add_figure(f"{tag}_box_plot", box_fig, curr_iter)
|
|
|
|
|
|
def print_batch_errors(batch_errors_list, namespace, curr_iter):
    """Print the mean and median of every aggregated batch metric.

    The node types and metric names are taken from the first entry of
    ``batch_errors_list``. For every (node type, metric) pair the values of
    all batches are concatenated. If any exist, two lines are printed,
    prefixed with ``"<curr_iter>: <node_type.name>/<namespace>/<metric>"``.

    :param batch_errors_list: List of dicts shaped like
        ``{node_type: {metric: [values, ...]}}``. An empty list prints
        nothing.
    :param namespace: Middle component of the printed label.
    :param curr_iter: Value printed as the line prefix.
    """
    # Nothing to aggregate, and indexing [0] below would raise IndexError.
    if not batch_errors_list:
        return

    for node_type, metrics in batch_errors_list[0].items():
        for metric in metrics:
            metric_batch_error = []
            for batch_errors in batch_errors_list:
                metric_batch_error.extend(batch_errors[node_type][metric])

            if metric_batch_error:
                print(f"{curr_iter}: {node_type.name}/{namespace}/{metric}_mean", np.mean(metric_batch_error))
                print(f"{curr_iter}: {node_type.name}/{namespace}/{metric}_median", np.median(metric_batch_error))
|