Trajectron-plus-plus/experiments/nuScenes/evaluate.py

186 lines
9.1 KiB
Python
Raw Permalink Normal View History

import sys
import os
import dill
import json
import argparse
import torch
import numpy as np
import pandas as pd
sys.path.append("../../trajectron")
from tqdm import tqdm
from model.model_registrar import ModelRegistrar
from model.trajectron import Trajectron
import evaluation
import utils
from scipy.interpolate import RectBivariateSpline
# Seed every random number generator used here with the same fixed value.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Command-line interface. The single-valued options all have the same shape
# (flag, help text, value type), so they are registered from a table, in the
# order they are listed by --help.
parser = argparse.ArgumentParser()
for _flag, _help_text, _value_type in (
        ("--model", "model full path", str),
        ("--checkpoint", "model checkpoint to evaluate", int),
        ("--data", "full path to data file", str),
        ("--output_path", "path to output csv file", str),
        ("--output_tag", "name tag for output file", str),
        ("--node_type", "node type to evaluate", str),
):
    parser.add_argument(_flag, help=_help_text, type=_value_type)
# Accepts one or more integers; stays None when the flag is not given.
parser.add_argument("--prediction_horizon", nargs='+', help="prediction horizon", type=int, default=None)
args = parser.parse_args()
def compute_road_violations(predicted_trajs, map, channel):
    """Count predicted trajectories that touch a non-zero cell of the inverted map layer.

    Args:
        predicted_trajs: Array of 2-D positions. It is flattened to ``(-1, 2)``
            and the sampled values are reshaped back to its first three
            dimensions, so it is treated as ``(d0, d1, d2, 2)``
            (presumably batch x samples x timesteps - TODO confirm).
        map: Object exposing ``data`` (indexed as ``[..., channel, :, :]``) and
            ``to_map_points`` (converts positions to map coordinates).
        channel: Index of the map layer to test against.

    Returns:
        The number of ``(d0, d1)`` trajectories, as a float, whose largest
        sampled value along the third axis is greater than zero.
    """
    # Invert the layer: cells equal to 255 become 0, cells equal to 0 become 1.
    # NOTE(review): assumes layer values lie in [0, 255] - confirm against the map type.
    inverted_layer = 1 - map.data[..., channel, :, :] / 255
    layer_shape = inverted_layer.shape

    # Degree-1 spline in both directions, i.e. bilinear lookup on the cell grid.
    lookup = RectBivariateSpline(range(layer_shape[0]),
                                 range(layer_shape[1]),
                                 inverted_layer,
                                 kx=1, ky=1)

    traj_shape = predicted_trajs.shape
    leading_shape = tuple(traj_shape[axis] for axis in range(3))

    # One row per predicted position, expressed in map coordinates.
    map_points = map.to_map_points(predicted_trajs.reshape((-1, 2)))
    first_coord = map_points[:, 0]
    second_coord = map_points[:, 1]

    # grid=False evaluates point-by-point instead of on the outer product.
    sampled = lookup(first_coord, second_coord, grid=False)
    per_step = sampled.reshape(leading_shape)

    # A trajectory counts once if any of its steps has a positive value.
    touches_blocked_cell = per_step.max(axis=2) > 0
    return np.sum(touches_blocked_cell, dtype=float)
def load_model(model_dir, env, ts=100):
    """Rebuild a saved Trajectron model from ``model_dir`` and attach it to ``env``.

    Args:
        model_dir: Directory holding the saved model and its ``config.json``.
        env: Environment object handed to ``Trajectron.set_environment``.
        ts: Value forwarded to ``ModelRegistrar.load_models``
            (presumably the checkpoint iteration - TODO confirm).

    Returns:
        A ``(trajectron, hyperparams)`` tuple: the configured model and the
        dictionary parsed from ``config.json``.
    """
    # Both the registrar and the model are constructed with the literal 'cpu'.
    registrar = ModelRegistrar(model_dir, 'cpu')
    registrar.load_models(ts)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r') as config_file:
        hyperparams = json.load(config_file)

    # Third positional argument is passed as None, as in the original call.
    model = Trajectron(registrar, hyperparams, None, 'cpu')
    model.set_environment(env)
    model.set_annealing_params()
    return model, hyperparams
def _write_metric_csv(values, metric, eval_type, output_path, output_tag, ph, file_suffix):
    """Write one metric's values to ``<output_path>/<output_tag>_<ph>_<file_suffix>.csv``.

    Args:
        values: Sequence/array stored in the ``value`` column.
        metric: Label stored in the ``metric`` column of every row.
        eval_type: Label stored in the ``type`` column of every row.
        output_path: Directory the CSV is written to.
        output_tag: Prefix of the file name.
        ph: Prediction horizon, embedded in the file name.
        file_suffix: Trailing part of the file name, without the extension.
    """
    csv_path = os.path.join(output_path, output_tag + "_" + str(ph) + '_' + file_suffix + '.csv')
    pd.DataFrame({'value': values, 'metric': metric, 'type': eval_type}).to_csv(csv_path)


if __name__ == "__main__":
    # Fail before any expensive loading: the horizons are iterated further
    # down, and argparse leaves them as None when the flag is omitted.
    if not args.prediction_horizon:
        parser.error("--prediction_horizon is required")

    # Create the output folder up front so the CSV writes at the end of a long
    # evaluation cannot fail on a missing directory. An empty path means the
    # current directory, which os.makedirs would reject, so it is skipped.
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)

    # Number of samples drawn per node in the "full" evaluation. The same value
    # is used below to recognise nodes for which every sample violates the map.
    num_full_samples = 2000

    with open(args.data, 'rb') as f:
        env = dill.load(f, encoding='latin1')

    eval_stg, hyperparams = load_model(args.model, env, ts=args.checkpoint)

    # Each override is a "<node_type1> <node_type2> <radius>" string.
    if 'override_attention_radius' in hyperparams:
        for attention_radius_override in hyperparams['override_attention_radius']:
            node_type1, node_type2, attention_radius = attention_radius_override.split(' ')
            env.attention_radius[(node_type1, node_type2)] = float(attention_radius)

    scenes = env.scenes

    print("-- Preparing Node Graph")
    for scene in tqdm(scenes):
        scene.calculate_scene_graph(env.attention_radius,
                                    hyperparams['edge_addition_filter'],
                                    hyperparams['edge_removal_filter'])

    for ph in args.prediction_horizon:
        print(f"Prediction Horizon: {ph}")
        max_hl = hyperparams['maximum_history_length']

        with torch.no_grad():
            ############### MOST LIKELY Z ###############
            eval_ade_batch_errors = np.array([])
            eval_fde_batch_errors = np.array([])

            print("-- Evaluating GMM Z Mode (Most Likely)")
            for scene in tqdm(scenes):
                timesteps = np.arange(scene.timesteps)

                predictions = eval_stg.predict(scene,
                                               timesteps,
                                               ph,
                                               num_samples=1,
                                               min_future_timesteps=8,
                                               z_mode=True,
                                               gmm_mode=True,
                                               full_dist=False)  # This will trigger grid sampling

                batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                                       scene.dt,
                                                                       max_hl=max_hl,
                                                                       ph=ph,
                                                                       node_type_enum=env.NodeType,
                                                                       map=None,
                                                                       prune_ph_to_future=False,
                                                                       kde=False)

                node_errors = batch_error_dict[args.node_type]
                eval_ade_batch_errors = np.hstack((eval_ade_batch_errors, node_errors['ade']))
                eval_fde_batch_errors = np.hstack((eval_fde_batch_errors, node_errors['fde']))

            print(np.mean(eval_fde_batch_errors))
            _write_metric_csv(eval_ade_batch_errors, 'ade', 'ml',
                              args.output_path, args.output_tag, ph, 'ade_most_likely_z')
            _write_metric_csv(eval_fde_batch_errors, 'fde', 'ml',
                              args.output_path, args.output_tag, ph, 'fde_most_likely_z')

            ############### FULL ###############
            eval_ade_batch_errors = np.array([])
            eval_fde_batch_errors = np.array([])
            eval_kde_nll = np.array([])
            eval_road_viols = np.array([])

            print("-- Evaluating Full")
            for scene in tqdm(scenes):
                timesteps = np.arange(scene.timesteps)

                predictions = eval_stg.predict(scene,
                                               timesteps,
                                               ph,
                                               num_samples=num_full_samples,
                                               min_future_timesteps=8,
                                               z_mode=False,
                                               gmm_mode=False,
                                               full_dist=False)

                # Nothing was predicted for this scene; skip it entirely.
                if not predictions:
                    continue

                prediction_dict, _, _ = utils.prediction_output_to_trajectories(predictions,
                                                                                scene.dt,
                                                                                max_hl,
                                                                                ph,
                                                                                prune_ph_to_future=False)

                eval_road_viols_batch = []
                for node_predictions in prediction_dict.values():
                    for node, node_trajs in node_predictions.items():
                        if node.type != args.node_type:
                            continue
                        viols = compute_road_violations(node_trajs,
                                                        scene.map[args.node_type],
                                                        channel=0)
                        # When every sample violates, the count is reset to 0
                        # (presumably the node itself is off the drivable
                        # area - TODO confirm).
                        if viols == num_full_samples:
                            viols = 0
                        eval_road_viols_batch.append(viols)

                eval_road_viols = np.hstack((eval_road_viols, eval_road_viols_batch))

                batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                                       scene.dt,
                                                                       max_hl=max_hl,
                                                                       ph=ph,
                                                                       node_type_enum=env.NodeType,
                                                                       map=None,
                                                                       prune_ph_to_future=False)

                node_errors = batch_error_dict[args.node_type]
                eval_ade_batch_errors = np.hstack((eval_ade_batch_errors, node_errors['ade']))
                eval_fde_batch_errors = np.hstack((eval_fde_batch_errors, node_errors['fde']))
                eval_kde_nll = np.hstack((eval_kde_nll, node_errors['kde']))

            _write_metric_csv(eval_ade_batch_errors, 'ade', 'full',
                              args.output_path, args.output_tag, ph, 'ade_full')
            _write_metric_csv(eval_fde_batch_errors, 'fde', 'full',
                              args.output_path, args.output_tag, ph, 'fde_full')
            _write_metric_csv(eval_kde_nll, 'kde', 'full',
                              args.output_path, args.output_tag, ph, 'kde_full')
            _write_metric_csv(eval_road_viols, 'road_viols', 'full',
                              args.output_path, args.output_tag, ph, 'rv_full')