477 lines
24 KiB
Python
477 lines
24 KiB
Python
import torch
|
|
from torch import nn, optim
|
|
import numpy as np
|
|
import os
|
|
import time
|
|
import psutil
|
|
import pickle
|
|
import json
|
|
import random
|
|
import argparse
|
|
import pathlib
|
|
import visualization
|
|
import evaluation
|
|
import matplotlib.pyplot as plt
|
|
from model.dyn_stg import SpatioTemporalGraphCVAEModel
|
|
from model.model_registrar import ModelRegistrar
|
|
from model.model_utils import cyclical_lr
|
|
from tensorboardX import SummaryWriter
|
|
#torch.autograd.set_detect_anomaly(True) # TODO Remove for speed
|
|
|
|
|
|
# Command-line interface for the training script. Every option has a default,
# so the script can be launched without any arguments.
parser = argparse.ArgumentParser()

# --- Model / graph structure -------------------------------------------------
parser.add_argument(
    "--conf",
    type=str,
    default='config.json',
    help="path to json config file for hyperparameters",
)
parser.add_argument(
    "--offline_scene_graph",
    type=str,
    default='yes',
    help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
)
parser.add_argument(
    "--dynamic_edges",
    type=str,
    default='yes',
    help="whether to use dynamic edges or not, options are 'no' and 'yes'",
)
parser.add_argument(
    "--edge_radius",
    type=float,
    default=3.0,
    help="the radius (in meters) within which two nodes will be connected by an edge",
)
parser.add_argument(
    "--edge_state_combine_method",
    type=str,
    default='sum',
    help="the method to use for combining edges of the same type",
)
parser.add_argument(
    "--edge_influence_combine_method",
    type=str,
    default='attention',
    help="the method to use for combining edge influences",
)
parser.add_argument(
    '--edge_addition_filter',
    nargs='+',
    type=float,
    default=[0.25, 0.5, 0.75, 1.0],  # We automatically pad left with 0.0
    help="what scaling to use for edges as they're created",
)
parser.add_argument(
    '--edge_removal_filter',
    nargs='+',
    type=float,
    default=[1.0, 0.0],  # We automatically pad right with 0.0
    help="what scaling to use for edges as they're removed",
)
parser.add_argument(
    '--incl_robot_node',
    action='store_true',
    help="whether to include a robot node in the graph or simply model all agents",
)
parser.add_argument(
    '--use_map_encoding',
    action='store_true',
    help="Whether to use map encoding or not",
)

# --- Data and logging locations ----------------------------------------------
parser.add_argument(
    "--data_dir",
    type=str,
    default='../data/processed',
    help="what dir to look in for data",
)
parser.add_argument(
    "--train_data_dict",
    type=str,
    default='nuScenes_train.pkl',
    help="what file to load for training data",
)
parser.add_argument(
    "--eval_data_dict",
    type=str,
    default='nuScenes_val.pkl',
    help="what file to load for evaluation data",
)
parser.add_argument(
    "--log_dir",
    type=str,
    default='../data/nuScenes/logs',
    help="what dir to save training information (i.e., saved models, logs, etc)",
)
parser.add_argument(
    "--log_tag",
    type=str,
    default='',
    help="tag for the log folder",
)

# --- Devices -----------------------------------------------------------------
parser.add_argument(
    '--device',
    type=str,
    default='cuda:1',
    help='what device to perform training on',
)
parser.add_argument(
    "--eval_device",
    type=str,
    default=None,
    help="what device to use during evaluation",
)

# --- Training schedule and batch sizes ---------------------------------------
parser.add_argument(
    "--num_iters",
    type=int,
    default=2000,
    help="number of iterations to train for",
)
parser.add_argument(
    '--batch_multiplier',
    type=int,
    default=1,
    help='how many minibatches to run per iteration of training',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=256,
    help='training batch size',
)
parser.add_argument(
    '--eval_batch_size',
    type=int,
    default=256,
    help='evaluation batch size',
)
parser.add_argument(
    '--k_eval',
    type=int,
    default=50,
    help='how many samples to take during evaluation',
)

# --- Reproducibility and periodic actions ------------------------------------
parser.add_argument(
    '--seed',
    type=int,
    default=123,
    help='manual seed to use, default is 123',
)
parser.add_argument(
    '--eval_every',
    type=int,
    default=50,
    help='how often to evaluate during training, never if None',
)
parser.add_argument(
    '--vis_every',
    type=int,
    default=50,
    help='how often to visualize during training, never if None',
)
parser.add_argument(
    '--save_every',
    type=int,
    default=100,
    help='how often to save during training, never if None',
)
|
|
# Parse the command line once at import time; the rest of the module reads
# the resulting namespace through the global ``args``.
args = parser.parse_args()

# Resolve the training device. ``args.device`` starts out as a string and is
# replaced by a ``torch.device`` object below.
if torch.cuda.is_available() and args.device != 'cpu':
    if torch.cuda.device_count() == 1:
        # If you have CUDA_VISIBLE_DEVICES set, which you should,
        # then this will prevent leftover flag arguments from
        # messing with the device allocation.
        args.device = 'cuda:0'
else:
    # No CUDA available, or the CPU was requested explicitly.
    args.device = 'cpu'
args.device = torch.device(args.device)

# Evaluation falls back to the CPU unless a device was given. Note that this
# value stays a plain string; it is not converted to a ``torch.device`` here.
args.eval_device = 'cpu' if args.eval_device is None else args.eval_device

# Seed every random number generator the script touches.
if args.seed is not None:
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
|
|
|
|
def _load_environment(data_path):
    """Unpickle and return the environment object stored at ``data_path``.

    NOTE(security): ``pickle.load`` can execute arbitrary code embedded in the
    file, so only point ``--data_dir`` at data files from a trusted source.
    """
    with open(data_path, 'rb') as f:
        return pickle.load(f, encoding='latin1')


def _precompute_scene_graphs(scenes, env, hyperparams):
    """Call ``calculate_scene_graph`` on every scene in ``scenes``.

    ``env`` is only dereferenced (for ``attention_radius``) inside the loop,
    so it may be ``None`` when ``scenes`` is empty.
    """
    for i, scene in enumerate(scenes):
        scene.calculate_scene_graph(env.attention_radius,
                                    hyperparams['state'],
                                    hyperparams['edge_addition_filter'],
                                    hyperparams['edge_removal_filter'])
        print(f"Created Scene Graph for Scene {i}")


def _build_optimizer(model_registrar, hyperparams):
    """Create the Adam optimizer and LR scheduler for ``learning_rate_style``.

    Returns:
        An ``(optimizer, lr_scheduler)`` tuple.

    Raises:
        ValueError: if ``hyperparams['learning_rate_style']`` is not one of
            ``'const'``, ``'exp'`` or ``'triangle'``. Without this check the
            training loop would fail later on an unbound local variable.
    """
    style = hyperparams['learning_rate_style']
    if style == 'const':
        optimizer = optim.Adam(model_registrar.parameters(), lr=hyperparams['learning_rate'])
        # gamma=1.0 keeps the learning rate constant.
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    elif style == 'exp':
        optimizer = optim.Adam(model_registrar.parameters(), lr=hyperparams['learning_rate'])
        lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=hyperparams['learning_decay_rate'])
    elif style == 'triangle':
        # Base lr is 1.0 because LambdaLR multiplies it by the value of ``clr``.
        optimizer = optim.Adam(model_registrar.parameters(), lr=1.0)
        clr = cyclical_lr(100,
                          min_lr=hyperparams['min_learning_rate'],
                          max_lr=hyperparams['learning_rate'],
                          decay=hyperparams['learning_decay_rate'])
        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, [clr])
    else:
        raise ValueError("Unknown learning_rate_style %r; expected 'const', 'exp' or 'triangle'" % (style,))
    return optimizer, lr_scheduler


def _log_prediction_figure(log_writer, tag, predictions, scene, timestep, curr_iter,
                           max_hl, ph, figsize=(5, 5), **vis_kwargs):
    """Plot ``predictions`` for ``scene`` and log the figure under ``tag``.

    Extra keyword arguments (e.g. ``map=...``) are forwarded unchanged to
    ``visualization.visualize_prediction``.
    """
    fig, ax = plt.subplots(figsize=figsize)
    visualization.visualize_prediction(ax,
                                       predictions,
                                       scene.dt,
                                       max_hl=max_hl,
                                       ph=ph,
                                       **vis_kwargs)
    ax.set_title(f"{scene.name}-t: {timestep}")
    log_writer.add_figure(tag, fig, curr_iter)


def main():
    """Train the spatio-temporal graph CVAE model configured by the global ``args``.

    Steps, in order: load hyperparameters from the JSON config and overlay the
    command-line values; create a timestamped model/log directory and save the
    merged config there; load the pickled training (and, if needed, evaluation)
    environments; optionally precompute scene graphs; build the training and
    evaluation models, the optimizer and the LR scheduler; then run
    ``args.num_iters`` training iterations with periodic visualization,
    evaluation and checkpointing.

    Raises:
        FileNotFoundError: if ``args.conf`` does not exist.
        ValueError: if the configured ``learning_rate_style`` is unknown.
    """
    # Load hyperparameters from json
    if not os.path.exists(args.conf):
        print('Config json not found!')
        # Fail here with the offending path instead of falling through to open().
        raise FileNotFoundError(args.conf)
    with open(args.conf, 'r') as conf_json:
        hyperparams = json.load(conf_json)

    # Add hyperparams from arguments (command-line values override the file).
    for name in ('dynamic_edges',
                 'edge_state_combine_method',
                 'edge_influence_combine_method',
                 'edge_radius',
                 'use_map_encoding',
                 'edge_addition_filter',
                 'edge_removal_filter',
                 'batch_size',
                 'k_eval',
                 'offline_scene_graph',
                 'incl_robot_node'):
        hyperparams[name] = getattr(args, name)

    print('-----------------------')
    print('| TRAINING PARAMETERS |')
    print('-----------------------')
    print('| iterations: %d' % args.num_iters)
    print('| batch_size: %d' % args.batch_size)
    print('| batch_multiplier: %d' % args.batch_multiplier)
    print('| effective batch size: %d (= %d * %d)' % (args.batch_size * args.batch_multiplier, args.batch_size, args.batch_multiplier))
    print('| device: %s' % args.device)
    print('| eval_device: %s' % args.eval_device)
    print('| Offline Scene Graph Calculation: %s' % args.offline_scene_graph)
    print('| edge_radius: %s' % args.edge_radius)
    print('| EE state_combine_method: %s' % args.edge_state_combine_method)
    print('| EIE scheme: %s' % args.edge_influence_combine_method)
    print('| dynamic_edges: %s' % args.dynamic_edges)
    print('| robot node: %s' % args.incl_robot_node)
    print('| map encoding: %s' % args.use_map_encoding)
    print('| edge_addition_filter: %s' % args.edge_addition_filter)
    print('| edge_removal_filter: %s' % args.edge_removal_filter)
    print('| MHL: %s' % hyperparams['minimum_history_length'])
    print('| PH: %s' % hyperparams['prediction_horizon'])
    print('-----------------------')

    # Create the log and model directory if they're not present.
    model_dir = os.path.join(args.log_dir, 'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag)
    pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

    # Save config to model directory
    with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json:
        json.dump(hyperparams, conf_json)

    log_writer = SummaryWriter(log_dir=model_dir)

    train_data_path = os.path.join(args.data_dir, args.train_data_dict)
    train_env = _load_environment(train_data_path)
    train_scenes = train_env.scenes
    print('Loaded training data from %s' % (train_data_path,))

    # Both periodic evaluation and periodic visualization read the evaluation
    # environment, so load it when either is enabled (previously it was only
    # loaded for evaluation, leaving ``eval_env`` undefined otherwise).
    needs_eval_model = args.eval_every is not None or args.vis_every is not None
    eval_env = None
    eval_scenes = []
    if needs_eval_model:
        eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
        eval_env = _load_environment(eval_data_path)
        eval_scenes = eval_env.scenes
        print('Loaded evaluation data from %s' % (eval_data_path, ))

    # Calculate Scene Graph
    if hyperparams['offline_scene_graph'] == 'yes':
        print("Offline calculating scene graphs")
        _precompute_scene_graphs(train_scenes, train_env, hyperparams)
        _precompute_scene_graphs(eval_scenes, eval_env, hyperparams)

    model_registrar = ModelRegistrar(model_dir, args.device)

    # We use pre trained weights for the map CNN
    if args.use_map_encoding:
        inf_encoder_registrar = os.path.join(args.log_dir, 'weight_trans/model_registrar-1499.pt')
        model_dict = torch.load(inf_encoder_registrar, map_location=args.device)

        for key, pretrained in model_dict.items():
            if 'map_encoder' in key:
                model_registrar.model_dict[key] = pretrained
                # Sanity check that the registrar hands back the injected object.
                assert model_registrar.get_model(key) is pretrained

    stg = SpatioTemporalGraphCVAEModel(model_registrar,
                                       hyperparams,
                                       log_writer, args.device)
    stg.set_scene_graph(train_env)
    stg.set_annealing_params()
    print('Created training STG model.')

    eval_stg = None
    if needs_eval_model:  # original condition read the misspelled ``args.vis_ervery``
        eval_stg = SpatioTemporalGraphCVAEModel(model_registrar,
                                                hyperparams,
                                                log_writer, args.device)
        eval_stg.set_scene_graph(eval_env)
        eval_stg.set_annealing_params()  # TODO Check if necessary

    optimizer, lr_scheduler = _build_optimizer(model_registrar, hyperparams)

    # Number of scenes drawn (with replacement) per training iteration; the
    # same number normalizes each loss so the accumulated gradient is a mean.
    scenes_per_iter = 10

    print_training_header(newline_start=True)
    for curr_iter in range(args.num_iters):
        # Necessary because we flip the weights contained between GPU and CPU sometimes.
        model_registrar.to(args.device)

        # Setting the current iterator value for internal logging.
        stg.set_curr_iter(curr_iter)
        if eval_stg is not None:
            eval_stg.set_curr_iter(curr_iter)

        # Stepping forward the learning rate scheduler and annealers.
        # NOTE(review): the scheduler is stepped before optimizer.step() and
        # get_lr() is read back; both are kept as-is to preserve the schedule.
        lr_scheduler.step()
        log_writer.add_scalar('train/learning_rate',
                              lr_scheduler.get_lr()[0],
                              curr_iter)
        stg.step_annealers()

        # Zeroing gradients for the upcoming iteration.
        optimizer.zero_grad()
        train_losses = {node_type: [] for node_type in train_env.NodeType}
        for scene in np.random.choice(train_scenes, scenes_per_iter):
            for mb_num in range(args.batch_multiplier):
                # Obtaining the batch's training loss.
                timesteps = scene.sample_timesteps(hyperparams['batch_size'])

                # Compute the training loss.
                train_loss_by_type = stg.train_loss(scene, timesteps, max_nodes=hyperparams['batch_size'])
                for node_type, train_loss in train_loss_by_type.items():
                    if train_loss is not None:
                        train_loss = train_loss / (args.batch_multiplier * scenes_per_iter)
                        train_losses[node_type].append(train_loss.item())

                        # Calculating gradients (accumulated across minibatches).
                        train_loss.backward()

        # Print training information. Also, no newline here. It's added in at a later line.
        print('{:9} | '.format(curr_iter), end='', flush=True)
        for node_type in train_env.NodeType:
            print('{}:{:10} | '.format(node_type.name[0], '%.2f' % sum(train_losses[node_type])), end='', flush=True)

        for node_type in train_env.NodeType:
            if len(train_losses[node_type]) > 0:
                log_writer.add_histogram(f"{node_type.name}/train/minibatch_losses", np.asarray(train_losses[node_type]), curr_iter)
                log_writer.add_scalar(f"{node_type.name}/train/loss", sum(train_losses[node_type]), curr_iter)

        # Clipping gradients.
        if hyperparams['grad_clip'] is not None:
            nn.utils.clip_grad_value_(model_registrar.parameters(), hyperparams['grad_clip'])

        # Performing a gradient step.
        optimizer.step()

        # Drop the reference to the last loss tensor. Rebinding instead of
        # ``del`` avoids an UnboundLocalError when no loss was produced.
        train_loss = None

        if args.vis_every is not None and (curr_iter + 1) % args.vis_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict random timestep to plot for train data set
                scene = np.random.choice(train_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = stg.predict(scene,
                                          timestep,
                                          ph,
                                          num_samples_z=100,
                                          most_likely_z=False,
                                          all_z=False)

                # Plot predicted timestep for random scene
                _log_prediction_figure(log_writer, 'train/prediction', predictions,
                                       scene, timestep, curr_iter, max_hl, ph)

                # Predict random timestep to plot for eval data set
                scene = np.random.choice(eval_scenes)
                timestep = scene.sample_timesteps(1, min_future_timesteps=ph)
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_z=100,
                                               most_likely_z=False,
                                               all_z=False,
                                               max_nodes=4 * args.eval_batch_size)

                # Plot predicted timestep for random scene
                _log_prediction_figure(log_writer, 'eval/prediction', predictions,
                                       scene, timestep, curr_iter, max_hl, ph)

                # Plot predicted timestep for random scene in map
                _log_prediction_figure(log_writer, 'eval/prediction_map', predictions,
                                       scene, timestep, curr_iter, max_hl, ph,
                                       figsize=(15, 15),
                                       map=scene.map['PLOT'])

                # Predict the same eval scene/timestep again, this time with all_z=True
                predictions = eval_stg.predict(scene,
                                               timestep,
                                               ph,
                                               num_samples_gmm=50,
                                               most_likely_z=False,
                                               all_z=True,
                                               max_nodes=4 * args.eval_batch_size)

                # Plot predicted timestep for random scene
                _log_prediction_figure(log_writer, 'eval/prediction_all_z', predictions,
                                       scene, timestep, curr_iter, max_hl, ph)

        if args.eval_every is not None and (curr_iter + 1) % args.eval_every == 0:
            max_hl = hyperparams['maximum_history_length']
            ph = hyperparams['prediction_horizon']
            with torch.no_grad():
                # Predict batch timesteps for training dataset evaluation
                train_batch_errors = []
                max_scenes = np.min([len(train_scenes), 5])
                for scene in np.random.choice(train_scenes, max_scenes):
                    timesteps = scene.sample_timesteps(args.eval_batch_size)
                    predictions = stg.predict(scene,
                                              timesteps,
                                              ph,
                                              num_samples_z=100,
                                              min_future_timesteps=ph,
                                              max_nodes=4 * args.eval_batch_size)

                    train_batch_errors.append(evaluation.compute_batch_statistics(predictions,
                                                                                  scene.dt,
                                                                                  max_hl=max_hl,
                                                                                  ph=ph,
                                                                                  node_type_enum=train_env.NodeType,
                                                                                  map=scene.map))

                evaluation.log_batch_errors(train_batch_errors,
                                            log_writer,
                                            'train',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict batch timesteps for evaluation dataset evaluation
                eval_batch_errors = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(args.eval_batch_size)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=100,
                                                   min_future_timesteps=ph,
                                                   max_nodes=4 * args.eval_batch_size)

                    eval_batch_errors.append(evaluation.compute_batch_statistics(predictions,
                                                                                 scene.dt,
                                                                                 max_hl=max_hl,
                                                                                 ph=ph,
                                                                                 node_type_enum=eval_env.NodeType,
                                                                                 map=scene.map))

                evaluation.log_batch_errors(eval_batch_errors,
                                            log_writer,
                                            'eval',
                                            curr_iter,
                                            bar_plot=['kde'],
                                            box_plot=['ade', 'fde'])

                # Predict maximum likelihood batch timesteps for evaluation dataset evaluation
                eval_batch_errors_ml = []
                for scene in eval_scenes:
                    timesteps = scene.sample_timesteps(scene.timesteps)

                    predictions = eval_stg.predict(scene,
                                                   timesteps,
                                                   ph,
                                                   num_samples_z=1,
                                                   min_future_timesteps=ph,
                                                   most_likely_z=True,
                                                   most_likely_gmm=True)

                    eval_batch_errors_ml.append(evaluation.compute_batch_statistics(predictions,
                                                                                    scene.dt,
                                                                                    max_hl=max_hl,
                                                                                    ph=ph,
                                                                                    map=scene.map,
                                                                                    node_type_enum=eval_env.NodeType,
                                                                                    kde=False))

                evaluation.log_batch_errors(eval_batch_errors_ml,
                                            log_writer,
                                            'eval/ml',
                                            curr_iter)

                eval_loss = []
                max_scenes = np.min([len(eval_scenes), 25])
                for scene in np.random.choice(eval_scenes, max_scenes):
                    # Sample timesteps from *this* scene. The original reused the
                    # ``timesteps`` left over from the last scene of the loop
                    # above, i.e. timesteps belonging to a different scene.
                    timesteps = scene.sample_timesteps(scene.timesteps)
                    eval_loss.append(eval_stg.eval_loss(scene, timesteps))

                evaluation.log_batch_errors(eval_loss,
                                            log_writer,
                                            'eval/loss',
                                            curr_iter)

        else:
            # No evaluation this iteration: pad the progress line instead.
            print('{:15} | {:10} | {:14}'.format('', '', ''),
                  end='', flush=True)

        # Here's the newline that ends the current training information printing.
        print('')

        if args.save_every is not None and (curr_iter + 1) % args.save_every == 0:
            model_registrar.save_models(curr_iter)
            print_training_header()
|
|
|
|
|
|
def print_training_header(newline_start=False):
    """Print the column header of the training progress table to stdout.

    Args:
        newline_start: when true, an empty line is printed before the header.
    """
    rows = [
        'Iteration | Train Loss | Eval NLL Q (IS) | Eval NLL P | Eval NLL Exact',
        '----------------------------------------------------------------------',
    ]
    if newline_start:
        rows.insert(0, '')

    # One print call; the emitted text matches printing each row separately.
    print('\n'.join(rows))
|
|
|
|
|
|
def memInUse():
    """Print the memory usage of the current process, scaled by 2**30."""
    this_process = psutil.Process(os.getpid())
    # Field 0 of psutil's memory_info() result is a byte count (psutil
    # documents it as the resident set size); dividing by 2**30 gives GiB.
    mem_bytes = this_process.memory_info()[0]
    print('memory GB:', mem_bytes / 2. ** 30)
|
|
|
|
|
|
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|