185 lines
9.1 KiB
Python
185 lines
9.1 KiB
Python
import sys
|
|
import os
|
|
import dill
|
|
import json
|
|
import argparse
|
|
import torch
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
sys.path.append("../../trajectron")
|
|
from tqdm import tqdm
|
|
from model.model_registrar import ModelRegistrar
|
|
from model.trajectron import Trajectron
|
|
import evaluation
|
|
import utils
|
|
from scipy.interpolate import RectBivariateSpline
|
|
|
|
# Seed the NumPy and PyTorch random number generators with one fixed value.
seed = 0
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(seed)

# The CUDA generators are only seeded when a CUDA device is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
|
|
|
|
# Command-line interface of the evaluation script.
#
# Every option below is consumed unconditionally by the script body (file
# paths are opened/joined, the tag is concatenated into file names, the
# horizons are iterated), so each one is marked as required: leaving one out
# now produces an argparse usage error immediately instead of a TypeError
# part-way through the run.
parser = argparse.ArgumentParser()
parser.add_argument("--model", help="model full path", type=str, required=True)
parser.add_argument("--checkpoint", help="model checkpoint to evaluate", type=int, required=True)
parser.add_argument("--data", help="full path to data file", type=str, required=True)
parser.add_argument("--output_path", help="path to output csv file", type=str, required=True)
parser.add_argument("--output_tag", help="name tag for output file", type=str, required=True)
parser.add_argument("--node_type", help="node type to evaluate", type=str, required=True)
# One or more integers; the script evaluates each horizon in turn.
parser.add_argument("--prediction_horizon", nargs='+', help="prediction horizon", type=int, default=None,
                    required=True)
# NOTE: parsing happens at module level, i.e. also when this file is imported.
args = parser.parse_args()
|
|
|
|
|
|
def compute_road_violations(predicted_trajs, map, channel):
    """Count the predicted trajectories that touch a non-zero obstacle value.

    The selected map channel is inverted and rescaled (``1 - value / 255``),
    so a raw value of 255 becomes 0 and anything lower becomes positive. The
    result is sampled with a degree-1 spline at every predicted point, and a
    trajectory counts as a violation when its largest sampled value along
    axis 2 is strictly positive.

    :param predicted_trajs: array whose last axis holds 2-D points and whose
        first three axes are preserved for the per-trajectory reduction
        (presumably batch x samples x timesteps x 2 -- confirm with caller).
    :param map: object exposing ``data`` (indexed as
        ``data[..., channel, :, :]``) and ``to_map_points(points)``.
    :param channel: index of the map channel to evaluate.
    :return: number of violating trajectories, as a float.
    """
    channel_layer = map.data[..., channel, :, :]
    obstacle_grid = 1 - channel_layer / 255

    # kx = ky = 1 makes the spline linear in each direction between cells.
    row_coords = range(obstacle_grid.shape[0])
    col_coords = range(obstacle_grid.shape[1])
    obstacle_lookup = RectBivariateSpline(row_coords, col_coords, obstacle_grid, kx=1, ky=1)

    traj_shape = predicted_trajs.shape
    flat_points = predicted_trajs.reshape((-1, 2))
    map_points = map.to_map_points(flat_points)

    # grid=False evaluates at the paired coordinates instead of on their outer product.
    sampled = obstacle_lookup(map_points[:, 0], map_points[:, 1], grid=False)
    sampled = sampled.reshape((traj_shape[0], traj_shape[1], traj_shape[2]))

    worst_per_traj = sampled.max(axis=2)
    return np.sum(worst_per_traj > 0, dtype=float)
|
|
|
|
|
|
def load_model(model_dir, env, ts=100):
    """Rebuild a Trajectron from a saved model directory.

    :param model_dir: directory holding the saved models and a ``config.json``
        file with the hyperparameters.
    :param env: environment object attached to the model via ``set_environment``.
    :param ts: checkpoint identifier forwarded to ``ModelRegistrar.load_models``.
    :return: tuple ``(trajectron, hyperparams)`` where ``hyperparams`` is the
        parsed content of ``config.json``.
    """
    device = 'cpu'  # both the registrar and the model are created on the CPU

    registrar = ModelRegistrar(model_dir, device)
    registrar.load_models(ts)

    config_path = os.path.join(model_dir, 'config.json')
    with open(config_path, 'r') as config_file:
        hyperparams = json.load(config_file)

    # Third positional argument is passed as None, as in the original call
    # (presumably an optional logger/writer -- confirm against Trajectron).
    model = Trajectron(registrar, hyperparams, None, device)
    model.set_environment(env)
    model.set_annealing_params()

    return model, hyperparams
|
|
|
|
|
|
# Number of trajectories sampled per node in the "Full" evaluation pass. The
# road-violation bookkeeping compares against this same value, so it is named
# once here instead of being repeated as a bare literal.
NUM_FULL_SAMPLES = 2000


def metric_csv_path(output_path, output_tag, ph, suffix):
    """Return ``<output_path>/<output_tag>_<ph>_<suffix>.csv``."""
    return os.path.join(output_path, output_tag + "_" + str(ph) + "_" + suffix + ".csv")


def stack_batches(batches):
    """Concatenate per-scene result batches into one flat array.

    Gives the same result as repeatedly ``np.hstack``-ing every batch onto an
    initially empty array, but copies the data once instead of once per scene.
    An empty ``batches`` list yields an empty float array.
    """
    return np.hstack([np.array([]), *batches])


def write_metric_csv(values, metric, eval_type, csv_path):
    """Write ``values`` to ``csv_path`` with constant ``metric``/``type`` columns."""
    pd.DataFrame({'value': values, 'metric': metric, 'type': eval_type}).to_csv(csv_path)


if __name__ == "__main__":
    # Make sure the output directory exists before starting, so a missing
    # directory does not abort the run at the first to_csv call.
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)

    with open(args.data, 'rb') as f:
        # NOTE(review): dill.load can execute arbitrary code stored in the
        # file; only evaluate data files from a trusted source.
        env = dill.load(f, encoding='latin1')

    eval_stg, hyperparams = load_model(args.model, env, ts=args.checkpoint)

    # Each override is a "<node_type1> <node_type2> <radius>" string.
    if 'override_attention_radius' in hyperparams:
        for attention_radius_override in hyperparams['override_attention_radius']:
            node_type1, node_type2, attention_radius = attention_radius_override.split(' ')
            env.attention_radius[(node_type1, node_type2)] = float(attention_radius)

    scenes = env.scenes

    print("-- Preparing Node Graph")
    for scene in tqdm(scenes):
        scene.calculate_scene_graph(env.attention_radius,
                                    hyperparams['edge_addition_filter'],
                                    hyperparams['edge_removal_filter'])

    for ph in args.prediction_horizon:
        print(f"Prediction Horizon: {ph}")
        max_hl = hyperparams['maximum_history_length']

        with torch.no_grad():
            ############### MOST LIKELY Z ###############
            ade_batches = []
            fde_batches = []

            print("-- Evaluating GMM Z Mode (Most Likely)")
            for scene in tqdm(scenes):
                timesteps = np.arange(scene.timesteps)

                predictions = eval_stg.predict(scene,
                                               timesteps,
                                               ph,
                                               num_samples=1,
                                               min_future_timesteps=8,
                                               z_mode=True,
                                               gmm_mode=True,
                                               full_dist=False)  # This will trigger grid sampling

                batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                                       scene.dt,
                                                                       max_hl=max_hl,
                                                                       ph=ph,
                                                                       node_type_enum=env.NodeType,
                                                                       map=None,
                                                                       prune_ph_to_future=False,
                                                                       kde=False)

                node_errors = batch_error_dict[args.node_type]
                ade_batches.append(node_errors['ade'])
                fde_batches.append(node_errors['fde'])

            eval_ade_batch_errors = stack_batches(ade_batches)
            eval_fde_batch_errors = stack_batches(fde_batches)

            print(np.mean(eval_fde_batch_errors))
            write_metric_csv(eval_ade_batch_errors, 'ade', 'ml',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'ade_most_likely_z'))
            write_metric_csv(eval_fde_batch_errors, 'fde', 'ml',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'fde_most_likely_z'))

            ############### FULL ###############
            ade_batches = []
            fde_batches = []
            kde_batches = []
            road_viol_batches = []

            print("-- Evaluating Full")
            for scene in tqdm(scenes):
                timesteps = np.arange(scene.timesteps)
                predictions = eval_stg.predict(scene,
                                               timesteps,
                                               ph,
                                               num_samples=NUM_FULL_SAMPLES,
                                               min_future_timesteps=8,
                                               z_mode=False,
                                               gmm_mode=False,
                                               full_dist=False)

                # Scenes without any prediction contribute nothing.
                if not predictions:
                    continue

                prediction_dict, _, _ = utils.prediction_output_to_trajectories(predictions,
                                                                                scene.dt,
                                                                                max_hl,
                                                                                ph,
                                                                                prune_ph_to_future=False)

                eval_road_viols_batch = []
                for node_predictions in prediction_dict.values():
                    for node, node_trajs in node_predictions.items():
                        if node.type == args.node_type:
                            viols = compute_road_violations(node_trajs,
                                                            scene.map[args.node_type],
                                                            channel=0)
                            # A node whose every sample violates is recorded as 0
                            # (presumably such nodes are treated as map artefacts
                            # rather than model errors -- confirm).
                            if viols == NUM_FULL_SAMPLES:
                                viols = 0

                            eval_road_viols_batch.append(viols)

                road_viol_batches.append(eval_road_viols_batch)

                batch_error_dict = evaluation.compute_batch_statistics(predictions,
                                                                       scene.dt,
                                                                       max_hl=max_hl,
                                                                       ph=ph,
                                                                       node_type_enum=env.NodeType,
                                                                       map=None,
                                                                       prune_ph_to_future=False)

                node_errors = batch_error_dict[args.node_type]
                ade_batches.append(node_errors['ade'])
                fde_batches.append(node_errors['fde'])
                kde_batches.append(node_errors['kde'])

            write_metric_csv(stack_batches(ade_batches), 'ade', 'full',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'ade_full'))
            write_metric_csv(stack_batches(fde_batches), 'fde', 'full',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'fde_full'))
            write_metric_csv(stack_batches(kde_batches), 'kde', 'full',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'kde_full'))
            write_metric_csv(stack_batches(road_viol_batches), 'road_viols', 'full',
                             metric_csv_path(args.output_path, args.output_tag, ph, 'rv_full'))
|