Trajectron-plus-plus/trajectron/evaluation/evaluation.py

141 lines
6.2 KiB
Python
Raw Normal View History

2020-01-13 19:55:45 +01:00
import numpy as np
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation
from scipy.stats import gaussian_kde
from trajectron.utils import prediction_output_to_trajectories
import trajectron.visualization as visualization
2020-01-13 19:55:45 +01:00
from matplotlib import pyplot as plt
def compute_ade(predicted_trajs, gt_traj):
    """Return the average displacement error (ADE) of each predicted trajectory.

    The Euclidean distance between prediction and ground truth is taken over
    the last axis (the coordinate axis), then averaged over the last remaining
    axis, and the result is flattened to 1-D.

    NOTE(review): callers in this file appear to pass predictions laid out as
    (batch, samples, timesteps, 2) and ground truth as (timesteps, 2), so the
    average would run over timesteps -- confirm against the caller.

    :param predicted_trajs: Array of predicted positions; must broadcast
        against ``gt_traj``.
    :param gt_traj: Array of ground-truth positions.
    :return: 1-D array with one mean displacement per predicted trajectory.
    """
    displacement = predicted_trajs - gt_traj
    per_step_distance = np.linalg.norm(displacement, axis=-1)
    return np.mean(per_step_distance, axis=-1).flatten()
2020-01-13 19:55:45 +01:00
def compute_fde(predicted_trajs, gt_traj):
    """Return the final displacement error (FDE) of each predicted trajectory.

    Only the last entry along the third axis of ``predicted_trajs`` and the
    last entry of ``gt_traj`` are compared; the Euclidean distance is taken
    over the last axis and the result is flattened to 1-D.

    :param predicted_trajs: Array with at least three axes; it is indexed as
        ``[:, :, -1]`` (presumably batch, samples, timesteps -- confirm).
    :param gt_traj: Ground-truth positions; ``gt_traj[-1]`` is used.
    :return: 1-D array with one final-step distance per predicted trajectory.
    """
    last_predicted = predicted_trajs[:, :, -1]
    last_observed = gt_traj[-1]
    return np.linalg.norm(last_predicted - last_observed, axis=-1).flatten()
2020-01-13 19:55:45 +01:00
def compute_kde_nll(predicted_trajs, gt_traj):
    """Return the mean KDE-based negative log-likelihood of the ground truth.

    For every batch entry and timestep a Gaussian KDE is fitted to the
    predicted samples and evaluated (in log space) at the ground-truth
    position. Each log-density is clipped from below at -20 and the values
    are averaged over all (batch, timestep) pairs.

    :param predicted_trajs: Array indexed as ``[batch, sample, timestep]``
        with the coordinates on the last axis.
    :param gt_traj: Array indexed as ``[timestep]`` with the coordinates on
        the last axis; its first dimension sets the number of timesteps.
    :return: The negated mean clipped log-likelihood, or NaN if any KDE fit
        raises ``numpy.linalg.LinAlgError`` (e.g. degenerate samples).
    """
    log_pdf_lower_bound = -20
    num_timesteps = gt_traj.shape[0]
    num_batches = predicted_trajs.shape[0]
    num_terms = num_timesteps * num_batches

    kde_ll = 0.
    for batch_num in range(num_batches):
        for timestep in range(num_timesteps):
            try:
                kde = gaussian_kde(predicted_trajs[batch_num, :, timestep].T)
                log_pdf = np.clip(kde.logpdf(gt_traj[timestep].T),
                                  a_min=log_pdf_lower_bound,
                                  a_max=None)[0]
            except np.linalg.LinAlgError:
                # One failed fit makes the whole average NaN, so stop here
                # instead of fitting KDEs whose results would be discarded.
                return np.nan
            kde_ll += log_pdf / num_terms

    return -kde_ll
def compute_obs_violations(predicted_trajs, map):
    """Count predicted trajectories that touch a (dilated) obstacle.

    The obstacle grid ``map.data`` is transposed, dilated by four iterations
    and linearly interpolated. Every predicted point is converted with
    ``map.to_map_points`` and looked up in that interpolant; a trajectory
    counts as violating when its largest interpolated value is positive.

    The last axis of ``predicted_trajs`` holds the two coordinates and the
    axis before it is reduced per trajectory, so both ``(samples, timesteps,
    2)`` and ``(batch, samples, timesteps, 2)`` inputs are accepted. The
    previous fixed two-axis reshape only worked for the former.

    :param predicted_trajs: Array of predicted positions, coordinates last.
    :param map: Object exposing ``data`` (2-D obstacle grid) and
        ``to_map_points`` (presumably world -> map-pixel conversion --
        confirm against the map class).
    :return: Number of violating trajectories, as a float.
    """
    obs_map = map.data

    # Cast the boolean dilation result to float so the spline fit receives
    # numeric data regardless of how the backend treats bool arrays.
    dilated_obstacles = binary_dilation(obs_map.T, iterations=4).astype(float)
    interp_obs_map = RectBivariateSpline(range(obs_map.shape[1]),
                                         range(obs_map.shape[0]),
                                         dilated_obstacles,
                                         kx=1, ky=1)

    # Everything except the trailing coordinate axis; restored after lookup.
    point_grid_shape = predicted_trajs.shape[:-1]
    pred_trajs_map = map.to_map_points(predicted_trajs.reshape((-1, 2)))

    traj_obs_values = interp_obs_map(pred_trajs_map[:, 0], pred_trajs_map[:, 1], grid=False)
    traj_obs_values = traj_obs_values.reshape(point_grid_shape)

    # Reduce over the per-trajectory point axis, then count the offenders.
    num_viol_trajs = np.sum(traj_obs_values.max(axis=-1) > 0, dtype=float)
    return num_viol_trajs
def compute_batch_statistics(prediction_output_dict,
                             dt,
                             max_hl,
                             ph,
                             node_type_enum,
                             kde=True,
                             obs=False,
                             map=None,
                             prune_ph_to_future=False,
                             best_of=False):
    """Compute per-node-type error metrics for one batch of predictions.

    The raw prediction output is converted to trajectories with
    ``prediction_output_to_trajectories``; for every (timestep, node) pair
    the ADE, FDE, KDE-NLL and obstacle-violation metrics are computed and
    appended to the lists belonging to ``node.type``.

    :param prediction_output_dict: Raw prediction output, passed through to
        ``prediction_output_to_trajectories``.
    :param dt: Passed through to ``prediction_output_to_trajectories``.
    :param max_hl: Passed through to ``prediction_output_to_trajectories``.
    :param ph: Passed through to ``prediction_output_to_trajectories``.
    :param node_type_enum: Iterable of node types; one result entry is
        created per element.
    :param kde: When false, the KDE metric is recorded as 0 instead of being
        computed.
    :param obs: When false, the obstacle-violation metric is recorded as 0
        instead of being computed.
    :param map: Map object handed to ``compute_obs_violations`` (only used
        when ``obs`` is true).
    :param prune_ph_to_future: Passed through to
        ``prediction_output_to_trajectories``.
    :param best_of: When true, only the minimum ADE and minimum FDE of each
        node are kept.
    :return: ``{node_type: {'ade': [...], 'fde': [...], 'kde': [...],
        'obs_viols': [...]}}``.
    """
    prediction_dict, _, futures_dict = prediction_output_to_trajectories(
        prediction_output_dict,
        dt,
        max_hl,
        ph,
        prune_ph_to_future=prune_ph_to_future)

    metric_names = ('ade', 'fde', 'kde', 'obs_viols')
    batch_error_dict = {
        node_type: {metric_name: [] for metric_name in metric_names}
        for node_type in node_type_enum
    }

    for t, predictions_at_t in prediction_dict.items():
        for node, node_prediction in predictions_at_t.items():
            node_future = futures_dict[t][node]

            ade_errors = compute_ade(node_prediction, node_future)
            fde_errors = compute_fde(node_prediction, node_future)
            kde_ll = compute_kde_nll(node_prediction, node_future) if kde else 0
            obs_viols = compute_obs_violations(node_prediction, map) if obs else 0

            if best_of:
                ade_errors = np.min(ade_errors, keepdims=True)
                fde_errors = np.min(fde_errors, keepdims=True)
                kde_ll = np.min(kde_ll)

            errors_for_type = batch_error_dict[node.type]
            errors_for_type['ade'].extend(list(ade_errors))
            errors_for_type['fde'].extend(list(fde_errors))
            errors_for_type['kde'].append(kde_ll)
            errors_for_type['obs_viols'].append(obs_viols)

    return batch_error_dict
def log_batch_errors(batch_errors_list, log_writer, namespace, curr_iter, bar_plot=(), box_plot=()):
    """Write aggregated batch error metrics to a summary writer.

    For every node type and metric found in the first batch entry, the
    values of all batches are concatenated and logged as a histogram plus
    mean and median scalars under ``<node_type.name>/<namespace>/<metric>``.
    Metrics without any values are skipped. Metrics named in ``bar_plot`` /
    ``box_plot`` additionally get a bar / box plot figure.

    :param batch_errors_list: List of dictionaries shaped
        ``{node_type: {metric: [values]}}``. An empty list logs nothing.
    :param log_writer: Object providing ``add_histogram``, ``add_scalar`` and
        ``add_figure`` (presumably a TensorBoard summary writer -- confirm).
    :param namespace: Name inserted into every tag and used as the plots'
        ``dataset`` label.
    :param curr_iter: Step value passed to every writer call.
    :param bar_plot: Metric names to draw as bar plots. The default is an
        immutable tuple so it can never be shared-and-mutated across calls.
    :param box_plot: Metric names to draw as box plots.
    """
    if not batch_errors_list:
        # Nothing was evaluated; previously this raised IndexError below.
        return

    for node_type, metrics in batch_errors_list[0].items():
        for metric in metrics:
            metric_batch_error = []
            for batch_errors in batch_errors_list:
                metric_batch_error.extend(batch_errors[node_type][metric])

            if len(metric_batch_error) == 0:
                continue

            tag = f"{node_type.name}/{namespace}/{metric}"
            log_writer.add_histogram(tag, metric_batch_error, curr_iter)
            log_writer.add_scalar(f"{tag}_mean", np.mean(metric_batch_error), curr_iter)
            log_writer.add_scalar(f"{tag}_median", np.median(metric_batch_error), curr_iter)

            if metric in bar_plot:
                bar_data = {'dataset': [namespace] * len(metric_batch_error),
                            metric: metric_batch_error}
                bar_fig, bar_ax = plt.subplots(figsize=(5, 5))
                visualization.visualization_utils.plot_barplots(bar_ax, bar_data, 'dataset', metric)
                log_writer.add_figure(f"{tag}_bar_plot", bar_fig, curr_iter)

            if metric in box_plot:
                box_data = {'dataset': [namespace] * len(metric_batch_error),
                            metric: metric_batch_error}
                box_fig, box_ax = plt.subplots(figsize=(5, 5))
                visualization.visualization_utils.plot_boxplots(box_ax, box_data, 'dataset', metric)
                log_writer.add_figure(f"{tag}_box_plot", box_fig, curr_iter)
def print_batch_errors(batch_errors_list, namespace, curr_iter):
    """Print the mean and median of every aggregated batch error metric.

    For every node type and metric found in the first batch entry, the
    values of all batches are concatenated and two lines are printed:
    ``<curr_iter>: <node_type.name>/<namespace>/<metric>_mean <value>`` and
    the same with ``_median``. Metrics without any values are skipped.

    :param batch_errors_list: List of dictionaries shaped
        ``{node_type: {metric: [values]}}``. An empty list prints nothing.
    :param namespace: Name inserted into every printed tag.
    :param curr_iter: Iteration number printed at the start of each line.
    """
    if not batch_errors_list:
        # Nothing was evaluated; previously this raised IndexError below.
        return

    for node_type, metrics in batch_errors_list[0].items():
        for metric in metrics:
            metric_batch_error = []
            for batch_errors in batch_errors_list:
                metric_batch_error.extend(batch_errors[node_type][metric])

            if len(metric_batch_error) == 0:
                continue

            print(f"{curr_iter}: {node_type.name}/{namespace}/{metric}_mean", np.mean(metric_batch_error))
            print(f"{curr_iter}: {node_type.name}/{namespace}/{metric}_median", np.median(metric_batch_error))