Minor changes for trap compatibility

This commit is contained in:
Ruben van de Ven 2024-12-13 10:38:12 +01:00
parent 51d6157af9
commit 4e883511d3
4 changed files with 20 additions and 8 deletions

View file

@ -163,7 +163,7 @@ parser.add_argument('--eval_every',
parser.add_argument('--vis_every', parser.add_argument('--vis_every',
help='how often to visualize during training, never if None', help='how often to visualize during training, never if None',
type=int, type=int,
default=1) default=None)
parser.add_argument('--save_every', parser.add_argument('--save_every',
help='how often to save during training, never if None', help='how often to save during training, never if None',

View file

@ -44,7 +44,7 @@ class Node(object):
return hash((self.type, self.id)) return hash((self.type, self.id))
def __repr__(self): def __repr__(self):
return '/'.join([self.type.name, self.id]) return '/'.join([self.type.name, str(self.id)])
def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False): def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False):
""" """

View file

@ -89,12 +89,15 @@ def main():
print('| PH: %s' % hyperparams['prediction_horizon']) print('| PH: %s' % hyperparams['prediction_horizon'])
print('-----------------------') print('-----------------------')
# TODO)) gets rid of torch/distributions/distribution.py:44: UserWarning: <class 'trajectron.model.components.gmm2d.GMM2D'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.
warnings.filterwarnings("ignore")
log_writer = None log_writer = None
model_dir = None model_dir = None
if not args.debug: if not args.debug:
# Create the log and model directiory if they're not present. # Create the log and model directiory if they're not present.
model_dir = os.path.join(args.log_dir, model_dir = os.path.join(args.log_dir,
'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag) 'models_' + time.strftime('%Y%m%d_%H_%M_%S', time.localtime()) + args.log_tag)
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
# Save config to model directory # Save config to model directory
@ -131,13 +134,14 @@ def main():
min_future_timesteps=hyperparams['prediction_horizon'], min_future_timesteps=hyperparams['prediction_horizon'],
return_robot=not args.incl_robot_node) return_robot=not args.incl_robot_node)
train_data_loader = dict() train_data_loader = dict()
print(train_scenes)
for node_type_data_set in train_dataset: for node_type_data_set in train_dataset:
if len(node_type_data_set) == 0: if len(node_type_data_set) == 0:
continue continue
node_type_dataloader = utils.data.DataLoader(node_type_data_set, node_type_dataloader = utils.data.DataLoader(node_type_data_set,
collate_fn=collate, collate_fn=collate,
pin_memory=False if args.device is 'cpu' else True, pin_memory=False if args.device == 'cpu' else True,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
num_workers=args.preprocess_workers) num_workers=args.preprocess_workers)
@ -180,7 +184,7 @@ def main():
node_type_dataloader = utils.data.DataLoader(node_type_data_set, node_type_dataloader = utils.data.DataLoader(node_type_data_set,
collate_fn=collate, collate_fn=collate,
pin_memory=False if args.eval_device is 'cpu' else True, pin_memory=False if args.eval_device == 'cpu' else True,
batch_size=args.eval_batch_size, batch_size=args.eval_batch_size,
shuffle=True, shuffle=True,
num_workers=args.preprocess_workers) num_workers=args.preprocess_workers)
@ -245,6 +249,7 @@ def main():
for epoch in range(1, args.train_epochs + 1): for epoch in range(1, args.train_epochs + 1):
model_registrar.to(args.device) model_registrar.to(args.device)
train_dataset.augment = args.augment train_dataset.augment = args.augment
# print('train', curr_iter_node_type)
for node_type, data_loader in train_data_loader.items(): for node_type, data_loader in train_data_loader.items():
curr_iter = curr_iter_node_type[node_type] curr_iter = curr_iter_node_type[node_type]
pbar = tqdm(data_loader, ncols=80) pbar = tqdm(data_loader, ncols=80)

View file

@ -17,7 +17,8 @@ def plot_trajectories(ax,
circle_edge_width=0.5, circle_edge_width=0.5,
node_circle_size=0.3, node_circle_size=0.3,
batch_num=0, batch_num=0,
kde=False): kde=False,
node_indexes=None):
cmap = ['k', 'b', 'y', 'g', 'r'] cmap = ['k', 'b', 'y', 'g', 'r']
@ -40,8 +41,12 @@ def plot_trajectories(ax,
ax=ax, shade=True, shade_lowest=False, ax=ax, shade=True, shade_lowest=False,
color=np.random.choice(cmap), alpha=0.8) color=np.random.choice(cmap), alpha=0.8)
if not node_indexes:
color = cmap[node.type.value]
else:
color = cmap[(node_indexes[node]+1) % len(cmap)]
ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1], ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1],
color=cmap[node.type.value], color=color,
linewidth=line_width, alpha=line_alpha) linewidth=line_width, alpha=line_alpha)
ax.plot(future[:, 0], ax.plot(future[:, 0],
@ -86,9 +91,11 @@ def visualize_prediction(ax,
histories_dict = histories_dict[ts_key] histories_dict = histories_dict[ts_key]
futures_dict = futures_dict[ts_key] futures_dict = futures_dict[ts_key]
node_indexes={node: nr for nr, node in enumerate(prediction_dict.keys())}
if map is not None: if map is not None:
ax.imshow(map.as_image(), origin='lower', alpha=0.5) ax.imshow(map.as_image(), origin='lower', alpha=0.5)
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs) plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, node_indexes=node_indexes, *kwargs)
def visualize_distribution(ax, def visualize_distribution(ax,