Minor changes for trap compatibility

commit 4e883511d3
parent 51d6157af9

4 changed files with 20 additions and 8 deletions
@@ -163,7 +163,7 @@ parser.add_argument('--eval_every',
 parser.add_argument('--vis_every',
                     help='how often to visualize during training, never if None',
                     type=int,
-                    default=1)
+                    default=None)

 parser.add_argument('--save_every',
                     help='how often to save during training, never if None',
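The `--vis_every` default moves from 1 to None, which matches the help text "never if None". A minimal sketch of how such a gate presumably looks in the training loop (the variable names and the print are placeholders, not taken from this commit):

    # Hypothetical gate, assuming vis_every=None is meant to mean "never visualize".
    vis_every = None  # what the new default would pass through as args.vis_every
    for epoch in range(1, 4):
        if vis_every is not None and epoch % vis_every == 0:
            print('visualize at epoch', epoch)  # never reached when vis_every is None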
@@ -44,7 +44,7 @@ class Node(object):
         return hash((self.type, self.id))

     def __repr__(self):
-        return '/'.join([self.type.name, self.id])
+        return '/'.join([self.type.name, str(self.id)])

     def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False):
         """
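Wrapping `self.id` in `str()` matters because `str.join` rejects non-string elements, so `__repr__` would raise a TypeError whenever a node id is numeric. A standalone illustration (not the project's Node class; the values are made up):

    type_name, node_id = 'PEDESTRIAN', 7
    # '/'.join([type_name, node_id]) -> TypeError: sequence item 1: expected str instance, int found
    print('/'.join([type_name, str(node_id)]))  # PEDESTRIAN/7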
@@ -89,12 +89,15 @@ def main():
     print('| PH: %s' % hyperparams['prediction_horizon'])
     print('-----------------------')

+    # TODO)) gets rid of torch/distributions/distribution.py:44: UserWarning: <class 'trajectron.model.components.gmm2d.GMM2D'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.
+    warnings.filterwarnings("ignore")
+
     log_writer = None
     model_dir = None
     if not args.debug:
         # Create the log and model directiory if they're not present.
         model_dir = os.path.join(args.log_dir,
-                                 'models_' + time.strftime('%d_%b_%Y_%H_%M_%S', time.localtime()) + args.log_tag)
+                                 'models_' + time.strftime('%Y%m%d_%H_%M_%S', time.localtime()) + args.log_tag)
         pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True)

         # Save config to model directory
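Two things change here. `warnings.filterwarnings("ignore")` silences all warnings; if only the GMM2D `arg_constraints` UserWarning is the nuisance, a narrower filter such as `warnings.filterwarnings("ignore", message=".*arg_constraints.*")` would presumably work as well. The model-directory timestamp also switches to a year-first format, so directory names sort chronologically. A quick comparison of the two formats (output values are illustrative):

    import time

    t = time.localtime()
    print(time.strftime('%d_%b_%Y_%H_%M_%S', t))  # old: e.g. 05_Mar_2021_14_02_33, sorts by day of month
    print(time.strftime('%Y%m%d_%H_%M_%S', t))    # new: e.g. 20210305_14_02_33, lexicographic order matches chronological order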
@@ -131,13 +134,14 @@ def main():
                                        min_future_timesteps=hyperparams['prediction_horizon'],
                                        return_robot=not args.incl_robot_node)
     train_data_loader = dict()
+    print(train_scenes)
     for node_type_data_set in train_dataset:
         if len(node_type_data_set) == 0:
             continue

         node_type_dataloader = utils.data.DataLoader(node_type_data_set,
                                                      collate_fn=collate,
-                                                     pin_memory=False if args.device is 'cpu' else True,
+                                                     pin_memory=False if args.device == 'cpu' else True,
                                                      batch_size=args.batch_size,
                                                      shuffle=True,
                                                      num_workers=args.preprocess_workers)
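Replacing `is 'cpu'` with `== 'cpu'` fixes a real bug: `is` compares object identity, and whether two equal strings share an object depends on interning, so the old check could silently take the wrong branch (CPython 3.8+ also emits "SyntaxWarning: 'is' with a literal" for it). A small demonstration:

    device = ''.join(['c', 'pu'])  # equal to 'cpu' but built at runtime, so a distinct object
    print(device == 'cpu')         # True
    print(device is 'cpu')         # False (identity, not equality)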
@@ -180,7 +184,7 @@ def main():

         node_type_dataloader = utils.data.DataLoader(node_type_data_set,
                                                      collate_fn=collate,
-                                                     pin_memory=False if args.eval_device is 'cpu' else True,
+                                                     pin_memory=False if args.eval_device == 'cpu' else True,
                                                      batch_size=args.eval_batch_size,
                                                      shuffle=True,
                                                      num_workers=args.preprocess_workers)
@@ -245,6 +249,7 @@ def main():
     for epoch in range(1, args.train_epochs + 1):
         model_registrar.to(args.device)
         train_dataset.augment = args.augment
+        # print('train', curr_iter_node_type)
         for node_type, data_loader in train_data_loader.items():
             curr_iter = curr_iter_node_type[node_type]
             pbar = tqdm(data_loader, ncols=80)
@@ -17,7 +17,8 @@ def plot_trajectories(ax,
                       circle_edge_width=0.5,
                       node_circle_size=0.3,
                       batch_num=0,
-                      kde=False):
+                      kde=False,
+                      node_indexes=None):

     cmap = ['k', 'b', 'y', 'g', 'r']

@@ -40,8 +41,12 @@ def plot_trajectories(ax,
                         ax=ax, shade=True, shade_lowest=False,
                         color=np.random.choice(cmap), alpha=0.8)

+        if not node_indexes:
+            color = cmap[node.type.value]
+        else:
+            color = cmap[(node_indexes[node]+1) % len(cmap)]
         ax.plot(predictions[batch_num, sample_num, :, 0], predictions[batch_num, sample_num, :, 1],
-                color=cmap[node.type.value],
+                color=color,
                 linewidth=line_width, alpha=line_alpha)

         ax.plot(future[:, 0],
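With `node_indexes` supplied, predicted trajectories are coloured per node instead of per node type, cycling through `cmap` with a modulo so any number of nodes can be drawn; the `+1` offset presumably shifts the first node off black ('k'). A short sketch of the cycling (node labels are made up):

    cmap = ['k', 'b', 'y', 'g', 'r']
    node_indexes = {'ped_0': 0, 'ped_1': 1, 'car_0': 2, 'ped_2': 3, 'ped_3': 4}
    for node, idx in node_indexes.items():
        print(node, cmap[(idx + 1) % len(cmap)])  # b, y, g, r, then wraps back to k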
@@ -86,9 +91,11 @@ def visualize_prediction(ax,
         histories_dict = histories_dict[ts_key]
         futures_dict = futures_dict[ts_key]

+    node_indexes={node: nr for nr, node in enumerate(prediction_dict.keys())}
+
     if map is not None:
         ax.imshow(map.as_image(), origin='lower', alpha=0.5)
-    plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
+    plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, node_indexes=node_indexes, *kwargs)


 def visualize_distribution(ax,
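`visualize_prediction` builds the node-to-index mapping itself and forwards it, so callers need no changes; the comprehension simply enumerates `prediction_dict`, meaning indices follow its insertion order. A standalone illustration (node labels are placeholders):

    prediction_dict = {'ped_0': None, 'car_0': None, 'ped_1': None}  # values irrelevant here
    node_indexes = {node: nr for nr, node in enumerate(prediction_dict.keys())}
    print(node_indexes)  # {'ped_0': 0, 'car_0': 1, 'ped_1': 2}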