Minor changes for trap compatibility
This commit is contained in:
parent
51d6157af9
commit
4e883511d3
4 changed files with 20 additions and 8 deletions
|
@ -163,7 +163,7 @@ parser.add_argument('--eval_every',
|
|||
parser.add_argument('--vis_every',
|
||||
help='how often to visualize during training, never if None',
|
||||
type=int,
|
||||
default=1)
|
||||
default=None)
|
||||
|
||||
parser.add_argument('--save_every',
|
||||
help='how often to save during training, never if None',
|
||||
|
|
|
@ -44,7 +44,7 @@ class Node(object):
|
|||
return hash((self.type, self.id))
|
||||
|
||||
def __repr__(self):
|
||||
return '/'.join([self.type.name, self.id])
|
||||
return '/'.join([self.type.name, str(self.id)])
|
||||
|
||||
def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False):
|
||||
"""
|
||||
|
|
|
@ -89,12 +89,15 @@ def main():
|
|||
print('| PH: %s' % hyperparams['prediction_horizon'])
|
||||
print('-----------------------')
|
||||
|
||||
# TODO)) gets rid of torch/distributions/distribution.py:44: UserWarning: <class 'trajectron.model.components.gmm2d.GMM2D'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
log_writer = None
|
||||
model_dir = None
|
||||
if not args.debug:
|
||||
# Create the log and model directiory if they're not present.
|
||||
model_dir = os.path.join(args.log_dir,
|
||||
'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag)
|
||||
'models_' + time.strftime('%Y%m%d_%H_%M_%S', time.localtime()) + args.log_tag)
|
||||
pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config to model directory
|
||||
|
@ -131,13 +134,14 @@ def main():
|
|||
min_future_timesteps=hyperparams['prediction_horizon'],
|
||||
return_robot=not args.incl_robot_node)
|
||||
train_data_loader = dict()
|
||||
print(train_scenes)
|
||||
for node_type_data_set in train_dataset:
|
||||
if len(node_type_data_set) == 0:
|
||||
continue
|
||||
|
||||
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
||||
collate_fn=collate,
|
||||
pin_memory=False if args.device is 'cpu' else True,
|
||||
pin_memory=False if args.device == 'cpu' else True,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.preprocess_workers)
|
||||
|
@ -180,7 +184,7 @@ def main():
|
|||
|
||||
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
||||
collate_fn=collate,
|
||||
pin_memory=False if args.eval_device is 'cpu' else True,
|
||||
pin_memory=False if args.eval_device == 'cpu' else True,
|
||||
batch_size=args.eval_batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.preprocess_workers)
|
||||
|
@ -245,6 +249,7 @@ def main():
|
|||
for epoch in range(1, args.train_epochs + 1):
|
||||
model_registrar.to(args.device)
|
||||
train_dataset.augment = args.augment
|
||||
# print('train', curr_iter_node_type)
|
||||
for node_type, data_loader in train_data_loader.items():
|
||||
curr_iter = curr_iter_node_type[node_type]
|
||||
pbar = tqdm(data_loader, ncols=80)
|
||||
|
|
|
@ -17,7 +17,8 @@ def plot_trajectories(ax,
|
|||
circle_edge_width=0.5,
|
||||
node_circle_size=0.3,
|
||||
batch_num=0,
|
||||
kde=False):
|
||||
kde=False,
|
||||
node_indexes=None):
|
||||
|
||||
cmap = ['k', 'b', 'y', 'g', 'r']
|
||||
|
||||
|
@ -40,8 +41,12 @@ def plot_trajectories(ax,
|
|||
ax=ax, shade=True, shade_lowest=False,
|
||||
color=np.random.choice(cmap), alpha=0.8)
|
||||
|
||||
if not node_indexes:
|
||||
color = cmap[node.type.value]
|
||||
else:
|
||||
color = cmap[(node_indexes[node]+1) % len(cmap)]
|
||||
ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1],
|
||||
color=cmap[node.type.value],
|
||||
color=color,
|
||||
linewidth=line_width, alpha=line_alpha)
|
||||
|
||||
ax.plot(future[:, 0],
|
||||
|
@ -86,9 +91,11 @@ def visualize_prediction(ax,
|
|||
histories_dict = histories_dict[ts_key]
|
||||
futures_dict = futures_dict[ts_key]
|
||||
|
||||
node_indexes={node: nr for nr, node in enumerate(prediction_dict.keys())}
|
||||
|
||||
if map is not None:
|
||||
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
|
||||
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
|
||||
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, node_indexes=node_indexes, *kwargs)
|
||||
|
||||
|
||||
def visualize_distribution(ax,
|
||||
|
|
Loading…
Reference in a new issue