2024-12-29 20:39:21 +01:00
import logging
from typing import List
2020-04-06 03:43:49 +02:00
import torch
from torch import nn , optim , utils
import numpy as np
import os
import time
import dill
import json
import random
import pathlib
import warnings
from tqdm import tqdm
2023-12-06 12:28:56 +01:00
import trajectron . visualization as visualization
import trajectron . evaluation as evaluation
2020-04-06 03:43:49 +02:00
import matplotlib . pyplot as plt
2023-12-06 12:28:56 +01:00
from trajectron . argument_parser import args
from trajectron . model . trajectron import Trajectron
from trajectron . model . model_registrar import ModelRegistrar
from trajectron . model . model_utils import cyclical_lr
from trajectron . model . dataset import EnvironmentDataset , collate
2024-12-29 20:39:21 +01:00
from trajectron . environment import Environment , Scene , Node
2020-04-06 03:43:49 +02:00
from tensorboardX import SummaryWriter
# torch.autograd.set_detect_anomaly(True)
# Resolve the training device: fall back to the CPU when CUDA is unavailable
# or the CPU was requested explicitly.
if not torch.cuda.is_available() or args.device == 'cpu':
    args.device = torch.device('cpu')
else:
    if torch.cuda.device_count() == 1:
        # If you have CUDA_VISIBLE_DEVICES set, which you should,
        # then this will prevent leftover flag arguments from
        # messing with the device allocation.
        args.device = 'cuda:0'

    args.device = torch.device(args.device)

# Evaluation runs on the CPU unless a device was given explicitly.
if args.eval_device is None:
    args.eval_device = torch.device('cpu')

# This is needed for memory pinning using a DataLoader (otherwise memory is
# pinned to cuda:0 by default). torch.cuda.set_device rejects non-CUDA
# devices, so it must be skipped when training on the CPU.
if args.device.type == 'cuda':
    torch.cuda.set_device(args.device)

# Seed every random number generator in use so that runs are reproducible.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
def _is_cpu_device(device):
    """Return True when ``device`` (a string or ``torch.device``) is a CPU device.

    Comparing a ``torch.device`` with the string ``'cpu'`` is always unequal,
    so the device is normalised first and its ``type`` is inspected instead.
    """
    return torch.device(device).type == 'cpu'


def _load_environment(data_path, hyperparams):
    """Load a pickled environment and apply the command line overrides to it.

    :param data_path: Path of the dill-serialised environment file.
    :param hyperparams: Hyperparameter dict; only ``incl_robot_node`` is read.
    :return: The loaded environment, with ``args.override_attention_radius``
             applied and, when requested and not already set, a robot type
             assigned to the environment and its scenes.
    """
    # NOTE(security): dill.load can execute arbitrary code while unpickling.
    # Only load environment files that come from a trusted source.
    with open(data_path, 'rb') as f:
        env = dill.load(f, encoding='latin1')

    for attention_radius_override in args.override_attention_radius:
        node_type1, node_type2, attention_radius = attention_radius_override.split(' ')
        env.attention_radius[(node_type1, node_type2)] = float(attention_radius)

    if env.robot_type is None and hyperparams['incl_robot_node']:
        env.robot_type = env.NodeType[0]  # TODO: Make more general, allow the user to specify?
        for scene in env.scenes:
            scene.add_robot_from_nodes(env.robot_type)

    return env


def _build_dataset(env, hyperparams, scene_freq_mult, node_freq_mult):
    """Wrap ``env`` in an EnvironmentDataset configured from ``hyperparams``."""
    return EnvironmentDataset(env,
                              hyperparams['state'],
                              hyperparams['pred_state'],
                              scene_freq_mult=scene_freq_mult,
                              node_freq_mult=node_freq_mult,
                              hyperparams=hyperparams,
                              min_history_timesteps=hyperparams['minimum_history_length'],
                              min_future_timesteps=hyperparams['prediction_horizon'],
                              return_robot=not args.incl_robot_node)


def _build_data_loaders(dataset, batch_size, device):
    """Create one shuffling DataLoader per non-empty node type dataset.

    :param dataset: Iterable of per-node-type datasets exposing ``node_type``.
    :param batch_size: Batch size used by every loader.
    :param device: Device the batches are consumed on; memory is pinned only
                   when this is not a CPU device.
    :return: Dict mapping node type to its DataLoader.
    """
    pin_memory = not _is_cpu_device(device)
    data_loaders = dict()
    for node_type_data_set in dataset:
        if len(node_type_data_set) == 0:
            continue

        data_loaders[node_type_data_set.node_type] = utils.data.DataLoader(
            node_type_data_set,
            collate_fn=collate,
            pin_memory=pin_memory,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.preprocess_workers)

    return data_loaders


def _log_prediction_figure(log_writer, tag, predictions, scene, timestep, max_hl, ph, epoch):
    """Plot ``predictions`` for ``scene`` and log the figure under ``tag``."""
    fig, ax = plt.subplots(figsize=(10, 10))
    visualization.visualize_prediction(ax,
                                       predictions,
                                       scene.dt,
                                       max_hl=max_hl,
                                       ph=ph,
                                       map=scene.map['VISUALIZATION'] if scene.map is not None else None)
    ax.set_title(f"{scene.name}-t: {timestep}")
    log_writer.add_figure(tag, fig, epoch)


def main():
    """Train a Trajectron model as configured by the command line arguments.

    Loads the hyperparameters and the training (and optionally evaluation)
    environments, builds the models, optimizers and learning rate schedulers,
    then runs the training loop with periodic visualization, evaluation and
    checkpointing. Logging, visualization, evaluation and checkpointing are
    all disabled when ``args.debug`` is set.

    :raises FileNotFoundError: If the configuration json does not exist.
    :raises ValueError: If ``learning_rate_style`` is neither 'const' nor 'exp'.
    """
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        raise FileNotFoundError(f"Config json not found! ({args.conf})")
    with open(args.conf, 'r', encoding='utf-8') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments
    hyperparams['dynamic_edges'] = args.dynamic_edges
    hyperparams['edge_state_combine_method'] = args.edge_state_combine_method
    hyperparams['edge_influence_combine_method'] = args.edge_influence_combine_method
    hyperparams['edge_addition_filter'] = args.edge_addition_filter
    hyperparams['edge_removal_filter'] = args.edge_removal_filter
    hyperparams['batch_size'] = args.batch_size
    hyperparams['k_eval'] = args.k_eval
    hyperparams['offline_scene_graph'] = args.offline_scene_graph
    hyperparams['incl_robot_node'] = args.incl_robot_node
    hyperparams['node_freq_mult_train'] = args.node_freq_mult_train
    hyperparams['node_freq_mult_eval'] = args.node_freq_mult_eval
    hyperparams['scene_freq_mult_train'] = args.scene_freq_mult_train
    hyperparams['scene_freq_mult_eval'] = args.scene_freq_mult_eval
    hyperparams['scene_freq_mult_viz'] = args.scene_freq_mult_viz
    hyperparams['edge_encoding'] = not args.no_edge_encoding
    hyperparams['use_map_encoding'] = args.map_encoding
    hyperparams['augment'] = args.augment
    hyperparams['override_attention_radius'] = args.override_attention_radius

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| batch_size: %d' % args.batch_size)
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    # TODO: this silences every warning. It was added to get rid of
    # torch/distributions/distribution.py:44: UserWarning: <class
    # 'trajectron.model.components.gmm2d.GMM2D'> does not define
    # `arg_constraints`. Please set `arg_constraints = {}` or initialize the
    # distribution with `validate_args=False` to turn off validation.
    warnings.filterwarnings("ignore")

    log_writer = None
    model_dir = None
    if not args.debug:
        # Create the log and model directory if they're not present.
        model_dir = os.path.join(args.log_dir,
                                 'models_' + time.strftime('%Y%m%d_%H_%M_%S', time.localtime()) + args.log_tag)
        pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

        # Save config to model directory
        with open(os.path.join(model_dir, 'config.json'), 'w', encoding='utf-8') as conf_json:
            json.dump(hyperparams, conf_json)

        log_writer = SummaryWriter(log_dir=model_dir)

    # Load training and evaluation environments and scenes
    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    train_env = _load_environment(train_data_path, hyperparams)

    train_scenes = train_env.scenes
    train_scenes_sample_probs = train_env.scenes_freq_mult_prop if args.scene_freq_mult_train else None

    train_dataset = _build_dataset(train_env,
                                   hyperparams,
                                   scene_freq_mult=hyperparams['scene_freq_mult_train'],
                                   node_freq_mult=hyperparams['node_freq_mult_train'])
    logging.debug("train_scenes=%r", train_scenes)
    train_data_loader = _build_data_loaders(train_dataset, args.batch_size, args.device)

    print(f"Loaded training data from {train_data_path}")

    # The evaluation environment only exists when periodic evaluation is
    # requested; everything that depends on it is guarded on that below.
    eval_env = None
    eval_scenes: List[Scene] = []
    eval_scenes_sample_probs = None
    eval_data_loader = dict()
    if args.eval_every is not None:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        eval_env = _load_environment(eval_data_path, hyperparams)

        eval_scenes = eval_env.scenes
        eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None

        eval_dataset = _build_dataset(eval_env,
                                      hyperparams,
                                      scene_freq_mult=hyperparams['scene_freq_mult_eval'],
                                      node_freq_mult=hyperparams['node_freq_mult_eval'])
        logging.debug("eval_scenes=%r", eval_scenes)
        eval_data_loader = _build_data_loaders(eval_dataset, args.eval_batch_size, args.eval_device)

        print(f"Loaded evaluation data from {eval_data_path}")

    # Offline Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print(f"Offline calculating scene graphs")
        for i, scene in enumerate(train_scenes):
            scene.calculate_scene_graph(train_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Training Scene {i}")

        for i, scene in enumerate(eval_scenes):
            scene.calculate_scene_graph(eval_env.attention_radius,
                                        hyperparams['edge_addition_filter'],
                                        hyperparams['edge_removal_filter'])
            print(f"Created Scene Graph for Evaluation Scene {i}")

    model_registrar = ModelRegistrar(model_dir, args.device)

    trajectron = Trajectron(model_registrar,
                            hyperparams,
                            log_writer,
                            args.device)

    trajectron.set_environment(train_env)
    trajectron.set_annealing_params()
    print('Created Training Model.')

    # The evaluation model shares the registrar (and thus the weights) with
    # the training model. It needs an evaluation environment, so it is only
    # built when one was loaded.
    eval_trajectron = None
    if eval_env is not None:
        eval_trajectron = Trajectron(model_registrar,
                                     hyperparams,
                                     log_writer,
                                     args.eval_device)
        eval_trajectron.set_environment(eval_env)
        eval_trajectron.set_annealing_params()
    print('Created Evaluation Model.')

    optimizer = dict()
    lr_scheduler = dict()
    for node_type in train_env.NodeType:
        if node_type not in hyperparams['pred_state']:
            continue
        # The map encoder parameters use their own fixed learning rate.
        optimizer[node_type] = optim.Adam(
            [{'params': model_registrar.get_all_but_name_match('map_encoder').parameters()},
             {'params': model_registrar.get_name_match('map_encoder').parameters(), 'lr': 0.0008}],
            lr=hyperparams['learning_rate'])
        # Set Learning Rate
        if hyperparams['learning_rate_style'] == 'const':
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(optimizer[node_type], gamma=1.0)
        elif hyperparams['learning_rate_style'] == 'exp':
            lr_scheduler[node_type] = optim.lr_scheduler.ExponentialLR(optimizer[node_type],
                                                                       gamma=hyperparams['learning_decay_rate'])
        else:
            # Fail before training starts instead of with a KeyError on the
            # first scheduler step.
            raise ValueError(f"Unknown learning_rate_style: {hyperparams['learning_rate_style']!r}")

    #################################
    #           TRAINING            #
    #################################
    curr_iter_node_type = {node_type: 0 for node_type in train_data_loader.keys()}
    for epoch in range(1, args.train_epochs + 1):
        # Visualization/evaluation may have moved the model to the eval device.
        model_registrar.to(args.device)
        train_dataset.augment = args.augment
        for node_type, data_loader in train_data_loader.items():
            curr_iter = curr_iter_node_type[node_type]
            pbar = tqdm(data_loader, ncols=80)
            for batch in pbar:
                trajectron.set_curr_iter(curr_iter)
                trajectron.step_annealers(node_type)
                optimizer[node_type].zero_grad()
                train_loss = trajectron.train_loss(batch, node_type)
                pbar.set_description(f"Epoch {epoch}, {node_type} L: {train_loss.item():.2f}")
                train_loss.backward()
                # Clipping gradients.
                if hyperparams['grad_clip'] is not None:
                    nn.utils.clip_grad_value_(model_registrar.parameters(), hyperparams['grad_clip'])
                optimizer[node_type].step()

                # Stepping forward the learning rate scheduler and annealers.
                lr_scheduler[node_type].step()

                if not args.debug:
                    # get_last_lr() reports the rate set by the last step();
                    # get_lr() is not meant to be called outside of step().
                    log_writer.add_scalar(f"{node_type}/train/learning_rate",
                                          lr_scheduler[node_type].get_last_lr()[0],
                                          curr_iter)
                    log_writer.add_scalar(f"{node_type}/train/loss", train_loss, curr_iter)

                curr_iter += 1
            curr_iter_node_type[node_type] = curr_iter
        train_dataset.augment = False
        if eval_trajectron is not None:
            eval_trajectron.set_curr_iter(epoch)

        #################################
        #        VISUALIZATION          #
        #################################
        if args.vis_every is not None and not args.debug and epoch % args.vis_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                if args.scene_freq_mult_viz:
                    scene = np.random.choice(train_scenes, p=train_scenes_sample_probs)
                else:
                    scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = trajectron.predict(scene,
                                                 timestep,
                                                 ph,
                                                 min_future_timesteps=ph,
                                                 z_mode=True,
                                                 gmm_mode=True,
                                                 all_z_sep=False,
                                                 full_dist=False)

                # Plot predicted timestep for random scene
                _log_prediction_figure(log_writer, 'train/prediction', predictions,
                                       scene, timestep, max_hl, ph, epoch)

                # Evaluation plots need both an evaluation model and scenes.
                if eval_trajectron is not None and len(eval_scenes) > 0:
                    model_registrar.to(args.eval_device)
                    # Predict random timestep to plot for eval data set
                    if args.scene_freq_mult_viz:
                        scene = np.random.choice(eval_scenes, p=eval_scenes_sample_probs)
                    else:
                        scene = np.random.choice(eval_scenes)
                    timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                    predictions = eval_trajectron.predict(scene,
                                                          timestep,
                                                          ph,
                                                          num_samples=20,
                                                          min_future_timesteps=ph,
                                                          z_mode=False,
                                                          full_dist=False)

                    # Plot predicted timestep for random scene
                    _log_prediction_figure(log_writer, 'eval/prediction', predictions,
                                           scene, timestep, max_hl, ph, epoch)

                    # Predict the same timestep again, separately for every z
                    predictions = eval_trajectron.predict(scene,
                                                          timestep,
                                                          ph,
                                                          min_future_timesteps=ph,
                                                          z_mode=True,
                                                          gmm_mode=True,
                                                          all_z_sep=True,
                                                          full_dist=False)

                    # Plot predicted timestep for random scene
                    _log_prediction_figure(log_writer, 'eval/prediction_all_z', predictions,
                                           scene, timestep, max_hl, ph, epoch)

        #################################
        #           EVALUATION          #
        #################################
        if args.eval_every is not None and not args.debug and epoch % args.eval_every == 0 and epoch > 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            model_registrar.to(args.eval_device)
            with torch.no_grad():
                # Calculate evaluation loss
                for node_type, data_loader in eval_data_loader.items():
                    eval_loss = []
                    print(f"Starting Evaluation @ epoch {epoch} for node type: {node_type}")
                    pbar = tqdm(data_loader, ncols=80)
                    for batch in pbar:
                        eval_loss_node_type = eval_trajectron.eval_loss(batch, node_type)
                        pbar.set_description(f"Epoch {epoch}, {node_type} L: {eval_loss_node_type.item():.2f}")
                        eval_loss.append({node_type: {'nll': [eval_loss_node_type]}})
                        del batch

                    evaluation.log_batch_errors(eval_loss,
                                                log_writer,
                                                f"{node_type}/eval_loss",
                                                epoch)

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in tqdm(eval_scenes, desc='Sample Evaluation', ncols=80):
                    logging.debug("%s, scene.timesteps=%r, %s", scene, scene.timesteps, len(scene.nodes))
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_trajectron.predict(scene,
                                                          timesteps,
                                                          ph,
                                                          num_samples=50,
                                                          min_future_timesteps=ph,
                                                          full_dist=False)

                    eval_batch_errors.append(evaluation.compute_batch_statistics(predictions,
                                                                                 scene.dt,
                                                                                 max_hl=max_hl,
                                                                                 ph=ph,
                                                                                 node_type_enum=eval_env.NodeType,
                                                                                 map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            epoch,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80):
                    logging.debug("%s, scene.timesteps=%r, %s", scene, scene.timesteps, len(scene.nodes))
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_trajectron.predict(scene,
                                                          timesteps,
                                                          ph,
                                                          num_samples=1,
                                                          min_future_timesteps=ph,
                                                          z_mode=True,
                                                          gmm_mode=True,
                                                          full_dist=False)

                    eval_batch_errors_ml.append(evaluation.compute_batch_statistics(predictions,
                                                                                    scene.dt,
                                                                                    max_hl=max_hl,
                                                                                    ph=ph,
                                                                                    map=scene.map,
                                                                                    node_type_enum=eval_env.NodeType,
                                                                                    kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml,
                                            log_writer,
                                            'eval/ml',
                                            epoch)

        if args.save_every is not None and args.debug is False and epoch % args.save_every == 0:
            model_registrar.save_models(epoch)
# Start training only when this file is executed as a script, so that
# importing the module does not launch a training run through this guard.
if __name__ == '__main__':
    main()