From 4e883511d31c96f0a894afc714fed05e322cea5b Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Fri, 13 Dec 2024 10:38:12 +0100 Subject: [PATCH] Minor changes for trap compatibility --- trajectron/argument_parser.py | 2 +- trajectron/environment/node.py | 2 +- trajectron/train.py | 11 ++++++++--- trajectron/visualization/visualization.py | 13 ++++++++++--- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/trajectron/argument_parser.py b/trajectron/argument_parser.py index 8da1fd9..ceb1d16 100644 --- a/trajectron/argument_parser.py +++ b/trajectron/argument_parser.py @@ -163,7 +163,7 @@ parser.add_argument('--eval_every', parser.add_argument('--vis_every', help='how often to visualize during training, never if None', type=int, - default=1) + default=None) parser.add_argument('--save_every', help='how often to save during training, never if None', diff --git a/trajectron/environment/node.py b/trajectron/environment/node.py index 316888a..d349b77 100644 --- a/trajectron/environment/node.py +++ b/trajectron/environment/node.py @@ -44,7 +44,7 @@ class Node(object): return hash((self.type, self.id)) def __repr__(self): - return '/'.join([self.type.name, self.id]) + return '/'.join([self.type.name, str(self.id)]) def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False): """ diff --git a/trajectron/train.py b/trajectron/train.py index a783728..4c3ac1d 100644 --- a/trajectron/train.py +++ b/trajectron/train.py @@ -89,12 +89,15 @@ def main(): print('| PH: %s' % hyperparams['prediction_horizon']) print('-----------------------') + # TODO)) gets rid of torch/distributions/distribution.py:44: UserWarning: does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation. + warnings.filterwarnings("ignore") + log_writer = None model_dir = None if not args.debug: # Create the log and model directiory if they're not present. model_dir = os.path.join(args.log_dir, - 'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag) + 'models_' + time.strftime('%Y%m%d_%H_%M_%S', time.localtime()) + args.log_tag) pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) # Save config to model directory @@ -131,13 +134,14 @@ def main(): min_future_timesteps=hyperparams['prediction_horizon'], return_robot=not args.incl_robot_node) train_data_loader = dict() + print(train_scenes) for node_type_data_set in train_dataset: if len(node_type_data_set) == 0: continue node_type_dataloader = utils.data.DataLoader(node_type_data_set, collate_fn=collate, - pin_memory=False if args.device is 'cpu' else True, + pin_memory=False if args.device == 'cpu' else True, batch_size=args.batch_size, shuffle=True, num_workers=args.preprocess_workers) @@ -180,7 +184,7 @@ def main(): node_type_dataloader = utils.data.DataLoader(node_type_data_set, collate_fn=collate, - pin_memory=False if args.eval_device is 'cpu' else True, + pin_memory=False if args.eval_device == 'cpu' else True, batch_size=args.eval_batch_size, shuffle=True, num_workers=args.preprocess_workers) @@ -245,6 +249,7 @@ def main(): for epoch in range(1, args.train_epochs + 1): model_registrar.to(args.device) train_dataset.augment = args.augment + # print('train', curr_iter_node_type) for node_type, data_loader in train_data_loader.items(): curr_iter = curr_iter_node_type[node_type] pbar = tqdm(data_loader, ncols=80) diff --git a/trajectron/visualization/visualization.py b/trajectron/visualization/visualization.py index d28c13e..6b9b5d5 100644 --- a/trajectron/visualization/visualization.py +++ b/trajectron/visualization/visualization.py @@ -17,7 +17,8 @@ def plot_trajectories(ax, circle_edge_width=0.5, node_circle_size=0.3, batch_num=0, - kde=False): + kde=False, + node_indexes=None): cmap = ['k', 'b', 'y', 'g', 'r'] @@ -40,8 +41,12 @@ def plot_trajectories(ax, ax=ax, shade=True, shade_lowest=False, color=np.random.choice(cmap), alpha=0.8) + if not node_indexes: + color = cmap[node.type.value] + else: + color = cmap[(node_indexes[node]+1) % len(cmap)] ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1], - color=cmap[node.type.value], + color=color, linewidth=line_width, alpha=line_alpha) ax.plot(future[:, 0], @@ -86,9 +91,11 @@ def visualize_prediction(ax, histories_dict = histories_dict[ts_key] futures_dict = futures_dict[ts_key] + node_indexes={node: nr for nr, node in enumerate(prediction_dict.keys())} + if map is not None: ax.imshow(map.as_image(), origin='lower', alpha=0.5) - plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs) + plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, node_indexes=node_indexes, *kwargs) def visualize_distribution(ax,