# Trajectron-plus-plus/experiments/nuScenes/evaluate.py
#
# 186 lines
# 9.1 KiB
# Python

import sys
import os
import dill
import json
import argparse
import torch
import numpy as np
import pandas as pd
sys.path.append("../../trajectron")
from tqdm import tqdm
from model.model_registrar import ModelRegistrar
from model.trajectron import Trajectron
import evaluation
import utils
from scipy.interpolate import RectBivariateSpline
# Seed every random number generator the evaluation draws from, so repeated
# runs sample the same trajectories and report the same metrics.
seed = 0
for _seeder in (np.random.seed, torch.manual_seed):
    _seeder(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
del _seeder
parser = argparse.ArgumentParser()
parser.add_argument("--model", help="model full path", type=str)
parser.add_argument("--checkpoint", help="model checkpoint to evaluate", type=int)
parser.add_argument("--data", help="full path to data file", type=str)
parser.add_argument("--output_path", help="path to output csv file", type=str)
parser.add_argument("--output_tag", help="name tag for output file", type=str)
parser.add_argument("--node_type", help="node type to evaluate", type=str)
parser.add_argument("--prediction_horizon", nargs='+', help="prediction horizon", type=int, default=None)
args = parser.parse_args()
def compute_road_violations(predicted_trajs, map, channel):  # noqa: A002 -- `map` kept for caller compatibility
    """Count sampled trajectories that touch an obstacle cell of one map channel.

    The selected raster channel is turned into an "obstacle value"
    ``1 - pixel / 255``, so a pixel of 255 scores 0 and a pixel of 0 scores 1.
    These values are bilinearly interpolated and sampled at every predicted
    position. A trajectory counts as a violation if any of its points has an
    obstacle value > 0.

    Args:
        predicted_trajs: array whose last axis holds 2-D positions and whose
            second-to-last axis is time. Any number of leading axes is
            accepted, e.g. ``(batch, samples, timesteps, 2)`` or
            ``(samples, timesteps, 2)``; the exact layout used by callers is
            assumed, TODO confirm.
        map: object exposing ``data`` (a raster indexable as
            ``[..., channel, :, :]``) and ``to_map_points``, which converts an
            ``(N, 2)`` array of positions into raster coordinates.
        channel: index of the raster channel to test against.

    Returns:
        The number of violating trajectories, as a float.
    """
    obs_map = 1 - map.data[..., channel, :, :] / 255
    # kx=ky=1 -> piecewise-linear (bilinear) interpolation over the pixel grid.
    interp_obs_map = RectBivariateSpline(np.arange(obs_map.shape[0]),
                                         np.arange(obs_map.shape[1]),
                                         obs_map,
                                         kx=1, ky=1)

    traj_shape = predicted_trajs.shape
    pred_trajs_map = map.to_map_points(predicted_trajs.reshape((-1, 2)))
    traj_obs_values = interp_obs_map(pred_trajs_map[:, 0], pred_trajs_map[:, 1], grid=False)
    # Restore every axis except the (x, y) one, so the last axis is time.
    traj_obs_values = traj_obs_values.reshape(traj_shape[:-1])

    # A trajectory violates if its worst point lies on an obstacle.
    return np.sum(traj_obs_values.max(axis=-1) > 0, dtype=float)
def load_model(model_dir, env, ts=100):
    """Restore a Trajectron model and its hyperparameters, on the CPU.

    Args:
        model_dir: directory handed to ``ModelRegistrar``; it must also contain
            the ``config.json`` hyperparameter file.
        env: environment the model is bound to via ``set_environment``.
        ts: checkpoint iteration passed to ``ModelRegistrar.load_models``.

    Returns:
        ``(trajectron, hyperparams)``: the ready-to-use model and the parsed
        contents of ``config.json``.
    """
    model_registrar = ModelRegistrar(model_dir, 'cpu')
    model_registrar.load_models(ts)

    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform's locale default (which differs on e.g. Windows).
    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r', encoding='utf-8') as config_json:
        hyperparams = json.load(config_json)

    trajectron = Trajectron(model_registrar, hyperparams, None, 'cpu')
    trajectron.set_environment(env)
    trajectron.set_annealing_params()
    return trajectron, hyperparams
# Number of trajectories sampled per node in the "Full" evaluation. The
# road-violation filter in ``_scene_road_violations`` compares against it, so
# both must use this one value.
NUM_FULL_SAMPLES = 2000

# Passed to ``predict`` as ``min_future_timesteps`` by both evaluation passes.
MIN_FUTURE_TIMESTEPS = 8


def _stack(chunks):
    """Flatten a list of per-scene metric chunks into one 1-D float array.

    Gives the same result as ``np.hstack``-ing each chunk onto an initially
    empty array, but concatenates once instead of once per scene.
    """
    return np.hstack([np.array([])] + list(chunks))


def _save_metric(values, metric, eval_type, ph, suffix):
    """Write ``values`` as a CSV with constant ``metric`` / ``type`` columns.

    The file is ``<output_path>/<output_tag>_<ph>_<suffix>.csv``.
    """
    file_name = args.output_tag + "_" + str(ph) + "_" + suffix + ".csv"
    pd.DataFrame({'value': values, 'metric': metric, 'type': eval_type}
                 ).to_csv(os.path.join(args.output_path, file_name))


def _evaluate_most_likely(eval_stg, env, scenes, ph, max_hl, node_type):
    """Score one prediction per node, sampled with ``z_mode`` and ``gmm_mode`` on.

    Returns:
        ``(ade, fde)``: flat arrays of errors for nodes of ``node_type``.
    """
    ade_chunks, fde_chunks = [], []
    for scene in tqdm(scenes):
        # NOTE(review): the original script commented that this call
        # "will trigger grid sampling"; confirm against Trajectron.predict.
        predictions = eval_stg.predict(scene,
                                       np.arange(scene.timesteps),
                                       ph,
                                       num_samples=1,
                                       min_future_timesteps=MIN_FUTURE_TIMESTEPS,
                                       z_mode=True,
                                       gmm_mode=True,
                                       full_dist=False)
        # Same guard as the "Full" pass: an empty scene contributes nothing.
        if not predictions:
            continue
        batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                               scene.dt,
                                                               max_hl=max_hl,
                                                               ph=ph,
                                                               node_type_enum=env.NodeType,
                                                               map=None,
                                                               prune_ph_to_future=False,
                                                               kde=False)
        ade_chunks.append(batch_error_dict[node_type]['ade'])
        fde_chunks.append(batch_error_dict[node_type]['fde'])
    return _stack(ade_chunks), _stack(fde_chunks)


def _scene_road_violations(prediction_dict, scene, node_type):
    """Return, per node of ``node_type``, how many samples hit map channel 0's obstacles.

    A node for which every one of the ``NUM_FULL_SAMPLES`` samples violates
    is recorded as 0. This is the original script's heuristic; presumably
    such nodes lie outside the map raster, TODO confirm.
    """
    violations = []
    for node_predictions in prediction_dict.values():
        for node, trajs in node_predictions.items():
            if node.type != node_type:
                continue
            viols = compute_road_violations(trajs, scene.map[node_type], channel=0)
            if viols == NUM_FULL_SAMPLES:
                viols = 0
            violations.append(viols)
    return violations


def _evaluate_full(eval_stg, env, scenes, ph, max_hl, node_type):
    """Score ``NUM_FULL_SAMPLES`` sampled predictions per node.

    Returns:
        ``(ade, fde, kde_nll, road_viols)``: flat arrays for nodes of
        ``node_type``.
    """
    ade_chunks, fde_chunks, kde_chunks, viol_chunks = [], [], [], []
    for scene in tqdm(scenes):
        predictions = eval_stg.predict(scene,
                                       np.arange(scene.timesteps),
                                       ph,
                                       num_samples=NUM_FULL_SAMPLES,
                                       min_future_timesteps=MIN_FUTURE_TIMESTEPS,
                                       z_mode=False,
                                       gmm_mode=False,
                                       full_dist=False)
        if not predictions:
            continue

        prediction_dict, _, _ = utils.prediction_output_to_trajectories(predictions,
                                                                        scene.dt,
                                                                        max_hl,
                                                                        ph,
                                                                        prune_ph_to_future=False)
        viol_chunks.append(_scene_road_violations(prediction_dict, scene, node_type))

        batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                               scene.dt,
                                                               max_hl=max_hl,
                                                               ph=ph,
                                                               node_type_enum=env.NodeType,
                                                               map=None,
                                                               prune_ph_to_future=False)
        ade_chunks.append(batch_error_dict[node_type]['ade'])
        fde_chunks.append(batch_error_dict[node_type]['fde'])
        kde_chunks.append(batch_error_dict[node_type]['kde'])
    return _stack(ade_chunks), _stack(fde_chunks), _stack(kde_chunks), _stack(viol_chunks)


if __name__ == "__main__":
    # NOTE(review): dill.load unpickles arbitrary objects and can execute
    # code; only point --data at trusted environment files.
    with open(args.data, 'rb') as f:
        env = dill.load(f, encoding='latin1')

    eval_stg, hyperparams = load_model(args.model, env, ts=args.checkpoint)

    # Optional overrides, each a string "<node_type1> <node_type2> <radius>".
    if 'override_attention_radius' in hyperparams:
        for attention_radius_override in hyperparams['override_attention_radius']:
            node_type1, node_type2, attention_radius = attention_radius_override.split(' ')
            env.attention_radius[(node_type1, node_type2)] = float(attention_radius)

    scenes = env.scenes

    print("-- Preparing Node Graph")
    for scene in tqdm(scenes):
        scene.calculate_scene_graph(env.attention_radius,
                                    hyperparams['edge_addition_filter'],
                                    hyperparams['edge_removal_filter'])

    max_hl = hyperparams['maximum_history_length']
    for ph in args.prediction_horizon:
        print(f"Prediction Horizon: {ph}")
        with torch.no_grad():
            ############### MOST LIKELY Z ###############
            print("-- Evaluating GMM Z Mode (Most Likely)")
            ml_ade, ml_fde = _evaluate_most_likely(eval_stg, env, scenes, ph, max_hl,
                                                   args.node_type)
            print(np.mean(ml_fde))
            _save_metric(ml_ade, 'ade', 'ml', ph, 'ade_most_likely_z')
            _save_metric(ml_fde, 'fde', 'ml', ph, 'fde_most_likely_z')

            ############### FULL ###############
            print("-- Evaluating Full")
            ade, fde, kde_nll, road_viols = _evaluate_full(eval_stg, env, scenes, ph, max_hl,
                                                           args.node_type)
            _save_metric(ade, 'ade', 'full', ph, 'ade_full')
            _save_metric(fde, 'fde', 'full', ph, 'fde_full')
            _save_metric(kde_nll, 'kde', 'full', ph, 'kde_full')
            _save_metric(road_viols, 'road_viols', 'full', ph, 'rv_full')