import sys import os import dill import json import argparse import torch import numpy as np import pandas as pd sys.path.append("../../trajectron") from tqdm import tqdm from model.model_registrar import ModelRegistrar from model.trajectron import Trajectron import evaluation import utils from scipy.interpolate import RectBivariateSpline seed = 0 np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) parser = argparse.ArgumentParser() parser.add_argument("--model", help="model full path", type=str) parser.add_argument("--checkpoint", help="model checkpoint to evaluate", type=int) parser.add_argument("--data", help="full path to data file", type=str) parser.add_argument("--output_path", help="path to output csv file", type=str) parser.add_argument("--output_tag", help="name tag for output file", type=str) parser.add_argument("--node_type", help="node type to evaluate", type=str) parser.add_argument("--prediction_horizon", nargs='+', help="prediction horizon", type=int, default=None) args = parser.parse_args() def compute_road_violations(predicted_trajs, map, channel): obs_map = 1 - map.data[..., channel, :, :] / 255 interp_obs_map = RectBivariateSpline(range(obs_map.shape[0]), range(obs_map.shape[1]), obs_map, kx=1, ky=1) old_shape = predicted_trajs.shape pred_trajs_map = map.to_map_points(predicted_trajs.reshape((-1, 2))) traj_obs_values = interp_obs_map(pred_trajs_map[:, 0], pred_trajs_map[:, 1], grid=False) traj_obs_values = traj_obs_values.reshape((old_shape[0], old_shape[1], old_shape[2])) num_viol_trajs = np.sum(traj_obs_values.max(axis=2) > 0, dtype=float) return num_viol_trajs def load_model(model_dir, env, ts=100): model_registrar = ModelRegistrar(model_dir, 'cpu') model_registrar.load_models(ts) with open(os.path.join(model_dir, 'config.json'), 'r') as config_json: hyperparams = json.load(config_json) trajectron = Trajectron(model_registrar, hyperparams, None, 'cpu') trajectron.set_environment(env) trajectron.set_annealing_params() return trajectron, hyperparams if __name__ == "__main__": with open(args.data, 'rb') as f: env = dill.load(f, encoding='latin1') eval_stg, hyperparams = load_model(args.model, env, ts=args.checkpoint) if 'override_attention_radius' in hyperparams: for attention_radius_override in hyperparams['override_attention_radius']: node_type1, node_type2, attention_radius = attention_radius_override.split(' ') env.attention_radius[(node_type1, node_type2)] = float(attention_radius) scenes = env.scenes print("-- Preparing Node Graph") for scene in tqdm(scenes): scene.calculate_scene_graph(env.attention_radius, hyperparams['edge_addition_filter'], hyperparams['edge_removal_filter']) for ph in args.prediction_horizon: print(f"Prediction Horizon: {ph}") max_hl = hyperparams['maximum_history_length'] with torch.no_grad(): ############### MOST LIKELY Z ############### eval_ade_batch_errors = np.array([]) eval_fde_batch_errors = np.array([]) print("-- Evaluating GMM Z Mode (Most Likely)") for scene in tqdm(scenes): timesteps = np.arange(scene.timesteps) predictions = eval_stg.predict(scene, timesteps, ph, num_samples=1, min_future_timesteps=8, z_mode=True, gmm_mode=True, full_dist=False) # This will trigger grid sampling batch_error_dict = evaluation.compute_batch_statistics(predictions, scene.dt, max_hl=max_hl, ph=ph, node_type_enum=env.NodeType, map=None, prune_ph_to_future=False, kde=False) eval_ade_batch_errors = np.hstack((eval_ade_batch_errors, batch_error_dict[args.node_type]['ade'])) eval_fde_batch_errors = np.hstack((eval_fde_batch_errors, batch_error_dict[args.node_type]['fde'])) print(np.mean(eval_fde_batch_errors)) pd.DataFrame({'value': eval_ade_batch_errors, 'metric': 'ade', 'type': 'ml'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_ade_most_likely_z.csv')) pd.DataFrame({'value': eval_fde_batch_errors, 'metric': 'fde', 'type': 'ml'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_fde_most_likely_z.csv')) ############### FULL ############### eval_ade_batch_errors = np.array([]) eval_fde_batch_errors = np.array([]) eval_kde_nll = np.array([]) eval_road_viols = np.array([]) print("-- Evaluating Full") for scene in tqdm(scenes): timesteps = np.arange(scene.timesteps) predictions = eval_stg.predict(scene, timesteps, ph, num_samples=2000, min_future_timesteps=8, z_mode=False, gmm_mode=False, full_dist=False) if not predictions: continue prediction_dict, _, _ = utils.prediction_output_to_trajectories(predictions, scene.dt, max_hl, ph, prune_ph_to_future=False) eval_road_viols_batch = [] for t in prediction_dict.keys(): for node in prediction_dict[t].keys(): if node.type == args.node_type: viols = compute_road_violations(prediction_dict[t][node], scene.map[args.node_type], channel=0) if viols == 2000: viols = 0 eval_road_viols_batch.append(viols) eval_road_viols = np.hstack((eval_road_viols, eval_road_viols_batch)) batch_error_dict = evaluation.compute_batch_statistics(predictions, scene.dt, max_hl=max_hl, ph=ph, node_type_enum=env.NodeType, map=None, prune_ph_to_future=False) eval_ade_batch_errors = np.hstack((eval_ade_batch_errors, batch_error_dict[args.node_type]['ade'])) eval_fde_batch_errors = np.hstack((eval_fde_batch_errors, batch_error_dict[args.node_type]['fde'])) eval_kde_nll = np.hstack((eval_kde_nll, batch_error_dict[args.node_type]['kde'])) pd.DataFrame({'value': eval_ade_batch_errors, 'metric': 'ade', 'type': 'full'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_ade_full.csv')) pd.DataFrame({'value': eval_fde_batch_errors, 'metric': 'fde', 'type': 'full'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_fde_full.csv')) pd.DataFrame({'value': eval_kde_nll, 'metric': 'kde', 'type': 'full'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_kde_full.csv')) pd.DataFrame({'value': eval_road_viols, 'metric': 'road_viols', 'type': 'full'} ).to_csv(os.path.join(args.output_path, args.output_tag + "_" + str(ph) + '_rv_full.csv'))