Compare commits
3 commits
4e883511d3
...
95da3d6136
Author | SHA1 | Date | |
---|---|---|---|
|
95da3d6136 | ||
|
dd64170ac3 | ||
|
e5151e1a32 |
3 changed files with 19 additions and 4 deletions
|
@ -31,6 +31,15 @@ class Node(object):
|
||||||
|
|
||||||
self.forward_in_time_on_next_override = False
|
self.forward_in_time_on_next_override = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first_timestep(self):
|
||||||
|
return self._first_timestep
|
||||||
|
|
||||||
|
@first_timestep.setter
|
||||||
|
def first_timestep(self, value):
|
||||||
|
self._first_timestep = value
|
||||||
|
self._last_timestep = None # reset
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return ((isinstance(other, self.__class__)
|
return ((isinstance(other, self.__class__)
|
||||||
or isinstance(self, other.__class__))
|
or isinstance(self, other.__class__))
|
||||||
|
|
|
@ -209,7 +209,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
if self.node not in maps:
|
if self.node not in maps:
|
||||||
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
||||||
me_params = self.hyperparams['map_encoder'][self.node_type]
|
me_params = self.hyperparams['map_encoder'][self.node_type]
|
||||||
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
|
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])).to(self.TD['node_history_encoded'].device)
|
||||||
else:
|
else:
|
||||||
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
||||||
(mode == ModeKeys.TRAIN))
|
(mode == ModeKeys.TRAIN))
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, optim, utils
|
from torch import nn, optim, utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -17,6 +19,8 @@ from trajectron.model.trajectron import Trajectron
|
||||||
from trajectron.model.model_registrar import ModelRegistrar
|
from trajectron.model.model_registrar import ModelRegistrar
|
||||||
from trajectron.model.model_utils import cyclical_lr
|
from trajectron.model.model_utils import cyclical_lr
|
||||||
from trajectron.model.dataset import EnvironmentDataset, collate
|
from trajectron.model.dataset import EnvironmentDataset, collate
|
||||||
|
from trajectron.environment import Environment, Scene, Node
|
||||||
|
|
||||||
from tensorboardX import SummaryWriter
|
from tensorboardX import SummaryWriter
|
||||||
# torch.autograd.set_detect_anomaly(True)
|
# torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
||||||
|
@ -134,7 +138,7 @@ def main():
|
||||||
min_future_timesteps=hyperparams['prediction_horizon'],
|
min_future_timesteps=hyperparams['prediction_horizon'],
|
||||||
return_robot=not args.incl_robot_node)
|
return_robot=not args.incl_robot_node)
|
||||||
train_data_loader = dict()
|
train_data_loader = dict()
|
||||||
print(train_scenes)
|
logging.debug(f"{train_scenes=}")
|
||||||
for node_type_data_set in train_dataset:
|
for node_type_data_set in train_dataset:
|
||||||
if len(node_type_data_set) == 0:
|
if len(node_type_data_set) == 0:
|
||||||
continue
|
continue
|
||||||
|
@ -165,7 +169,7 @@ def main():
|
||||||
for scene in eval_env.scenes:
|
for scene in eval_env.scenes:
|
||||||
scene.add_robot_from_nodes(eval_env.robot_type)
|
scene.add_robot_from_nodes(eval_env.robot_type)
|
||||||
|
|
||||||
eval_scenes = eval_env.scenes
|
eval_scenes: List[Scene] = eval_env.scenes
|
||||||
eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None
|
eval_scenes_sample_probs = eval_env.scenes_freq_mult_prop if args.scene_freq_mult_eval else None
|
||||||
|
|
||||||
eval_dataset = EnvironmentDataset(eval_env,
|
eval_dataset = EnvironmentDataset(eval_env,
|
||||||
|
@ -178,6 +182,7 @@ def main():
|
||||||
min_future_timesteps=hyperparams['prediction_horizon'],
|
min_future_timesteps=hyperparams['prediction_horizon'],
|
||||||
return_robot=not args.incl_robot_node)
|
return_robot=not args.incl_robot_node)
|
||||||
eval_data_loader = dict()
|
eval_data_loader = dict()
|
||||||
|
logging.debug(f"{eval_scenes=}")
|
||||||
for node_type_data_set in eval_dataset:
|
for node_type_data_set in eval_dataset:
|
||||||
if len(node_type_data_set) == 0:
|
if len(node_type_data_set) == 0:
|
||||||
continue
|
continue
|
||||||
|
@ -387,6 +392,7 @@ def main():
|
||||||
# Predict batch timesteps for evaluation dataset evaluation
|
# Predict batch timesteps for evaluation dataset evaluation
|
||||||
eval_batch_errors = []
|
eval_batch_errors = []
|
||||||
for scene in tqdm(eval_scenes, desc='Sample Evaluation', ncols=80):
|
for scene in tqdm(eval_scenes, desc='Sample Evaluation', ncols=80):
|
||||||
|
logging.debug(f"{scene}, {scene.timesteps=}, {len(scene.nodes)}")
|
||||||
timesteps = scene.sample_timesteps(args.eval_batch_size)
|
timesteps = scene.sample_timesteps(args.eval_batch_size)
|
||||||
|
|
||||||
predictions = eval_trajectron.predict(scene,
|
predictions = eval_trajectron.predict(scene,
|
||||||
|
@ -413,6 +419,7 @@ def main():
|
||||||
# Predict maximum likelihood batch timesteps for evaluation dataset evaluation
|
# Predict maximum likelihood batch timesteps for evaluation dataset evaluation
|
||||||
eval_batch_errors_ml = []
|
eval_batch_errors_ml = []
|
||||||
for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80):
|
for scene in tqdm(eval_scenes, desc='MM Evaluation', ncols=80):
|
||||||
|
logging.debug(f"{scene}, {scene.timesteps=}, {len(scene.nodes)}")
|
||||||
timesteps = scene.sample_timesteps(scene.timesteps)
|
timesteps = scene.sample_timesteps(scene.timesteps)
|
||||||
|
|
||||||
predictions = eval_trajectron.predict(scene,
|
predictions = eval_trajectron.predict(scene,
|
||||||
|
@ -423,7 +430,6 @@ def main():
|
||||||
z_mode=True,
|
z_mode=True,
|
||||||
gmm_mode=True,
|
gmm_mode=True,
|
||||||
full_dist=False)
|
full_dist=False)
|
||||||
|
|
||||||
eval_batch_errors_ml.append(evaluation.compute_batch_statistics(predictions,
|
eval_batch_errors_ml.append(evaluation.compute_batch_statistics(predictions,
|
||||||
scene.dt,
|
scene.dt,
|
||||||
max_hl=max_hl,
|
max_hl=max_hl,
|
||||||
|
|
Loading…
Reference in a new issue