diff --git a/trajectron/train.py b/trajectron/train.py index 4c3ac1d..29b8a98 100644 --- a/trajectron/train.py +++ b/trajectron/train.py @@ -1,3 +1,5 @@ +import logging +from typing import List import torch from torch import nn, optim, utils import numpy as np @@ -17,6 +19,8 @@ from trajectron.model.trajectron import Trajectron from trajectron.model.model_registrar import ModelRegistrar from trajectron.model.model_utils import cyclical_lr from trajectron.model.dataset import EnvironmentDataset, collate +from trajectron.environment import Environment, Scene, Node + from tensorboardX import SummaryWriter # torch.autograd.set_detect_anomaly(True) @@ -134,7 +138,7 @@ def main(): min_future_timesteps=hyperparams['prediction_horizon'], return_robot=not args.incl_robot_node) train_data_loader = dict() - print(train_scenes) + logging.debug(f"{train_scenes=}") for node_type_data_set in train_dataset: if len(node_type_data_set) == 0: continue @@ -165,7 +169,7 @@ def main(): for scene in eval_env.scenes: scene.add_robot_from_nodes(eval_env.robot_type) - eval_scenes = eval_env.scenes + eval_scenes: List[Scene] = eval_env.scenes eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None eval_dataset = EnvironmentDataset(eval_env, @@ -178,6 +182,7 @@ def main(): min_future_timesteps=hyperparams['prediction_horizon'], return_robot=not args.incl_robot_node) eval_data_loader = dict() + logging.debug(f"{eval_scenes=}") for node_type_data_set in eval_dataset: if len(node_type_data_set) == 0: continue @@ -387,6 +392,7 @@ def main(): # Predict batch timesteps for evaluation dataset evaluation eval_batch_errors = [] for scene in tqdm(eval_scenes, desc='Sample Evaluation', ncols=80): + logging.debug(f"{scene}, {scene.timesteps=}, {len(scene.nodes)}") timesteps = scene.sample_timesteps(args.eval_batch_size) predictions = eval_trajectron.predict(scene, @@ -413,6 +419,7 @@ def main(): # Predict maximum likelihood batch timesteps for evaluation dataset evaluation eval_batch_errors_ml = [] for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80): + logging.debug(f"{scene}, {scene.timesteps=}, {len(scene.nodes)}") timesteps = scene.sample_timesteps(scene.timesteps) predictions = eval_trajectron.predict(scene, @@ -423,7 +430,6 @@ def main(): z_mode=True, gmm_mode=True, full_dist=False) - eval_batch_errors_ml.append(evaluation.compute_batch_statistics(predictions, scene.dt, max_hl=max_hl,