Adding ability to run on streaming data.
This commit is contained in:
parent
b99e07dcab
commit
174198bd4f
14 changed files with 305 additions and 125 deletions
16
README.md
16
README.md
|
@ -130,6 +130,22 @@ If you instead wanted to evaluate a trained model's performance on forecasting p
|
||||||
|
|
||||||
These scripts will produce csv files in the `results` directory which can then be analyzed in the `NuScenes Quantitative.ipynb` notebook.
|
These scripts will produce csv files in the `results` directory which can then be analyzed in the `NuScenes Quantitative.ipynb` notebook.
|
||||||
|
|
||||||
|
## Online Execution ##
|
||||||
|
As of December 2020, this repository includes an "online" running capability. In addition to the regular batched mode for training and testing, Trajectron++ can now be executed online on streaming data!
|
||||||
|
|
||||||
|
The `trajectron/test_online.py` script shows how to use it, and can be run as follows (depending on the desired model).
|
||||||
|
|
||||||
|
| Model | Command | File Changes |
|
||||||
|
|-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
|
||||||
|
| Base | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'vel_ee'` |
|
||||||
|
| +Dynamics Integration | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'int_ee'` |
|
||||||
|
| +Dynamics Integration, Maps | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding | Line 110: `'int_ee_me'` |
|
||||||
|
| +Dynamics Integration, Maps, Robot Future | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding --incl_robot_node | Line 110: `'robot'` |
|
||||||
|
|
||||||
|
Further, lines 145-151 can be changed to choose different scenes and starting timesteps.
|
||||||
|
|
||||||
|
During running, each prediction will be iteratively visualized and saved in a `pred_figs/` folder within the specified model folder. For example, if the script loads the `int_ee` version of Trajectron++ then generated figures will be saved to `experiments/nuScenes/models/int_ee/pred_figs/`.
|
||||||
|
|
||||||
## Datasets ##
|
## Datasets ##
|
||||||
|
|
||||||
### ETH and UCY Pedestrian Datasets ###
|
### ETH and UCY Pedestrian Datasets ###
|
||||||
|
|
|
@ -46,7 +46,7 @@ class Node(object):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '/'.join([self.type.name, self.id])
|
return '/'.join([self.type.name, self.id])
|
||||||
|
|
||||||
def overwrite_data(self, data, forward_in_time_on_next_overwrite=False):
|
def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False):
|
||||||
"""
|
"""
|
||||||
This function hard overwrites the data matrix. When using it you have to make sure that the columns
|
This function hard overwrites the data matrix. When using it you have to make sure that the columns
|
||||||
in the new data matrix correspond to the old structure. As well as setting first_timestep.
|
in the new data matrix correspond to the old structure. As well as setting first_timestep.
|
||||||
|
@ -55,7 +55,11 @@ class Node(object):
|
||||||
:param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased.
|
:param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased.
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
|
if header is None:
|
||||||
self.data.data = data
|
self.data.data = data
|
||||||
|
else:
|
||||||
|
self.data = DoubleHeaderNumpyArray(data, header)
|
||||||
|
|
||||||
self._last_timestep = None
|
self._last_timestep = None
|
||||||
if self.forward_in_time_on_next_override:
|
if self.forward_in_time_on_next_override:
|
||||||
self.first_timestep += 1
|
self.first_timestep += 1
|
||||||
|
|
|
@ -25,23 +25,28 @@ class Scene(object):
|
||||||
self.non_aug_scene = non_aug_scene
|
self.non_aug_scene = non_aug_scene
|
||||||
|
|
||||||
def add_robot_from_nodes(self, robot_type):
|
def add_robot_from_nodes(self, robot_type):
|
||||||
nodes_list = [node for node in self.nodes if node.type == robot_type]
|
scenes = [self]
|
||||||
|
if hasattr(self, 'augmented'):
|
||||||
|
scenes += self.augmented
|
||||||
|
|
||||||
|
for scn in scenes:
|
||||||
|
nodes_list = [node for node in scn.nodes if node.type == robot_type]
|
||||||
non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3)
|
non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3)
|
||||||
self.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True)
|
scn.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True)
|
||||||
|
|
||||||
for node in non_overlapping_nodes:
|
for node in non_overlapping_nodes:
|
||||||
self.nodes.remove(node)
|
scn.nodes.remove(node)
|
||||||
self.nodes.append(self.robot)
|
scn.nodes.append(scn.robot)
|
||||||
|
|
||||||
def get_clipped_pos_dict(self, timestep, state):
|
def get_clipped_input_dict(self, timestep, state):
|
||||||
pos_dict = dict()
|
input_dict = dict()
|
||||||
existing_nodes = self.get_nodes_clipped_at_time(timesteps=np.array([timestep]),
|
existing_nodes = self.get_nodes_clipped_at_time(timesteps=np.array([timestep]),
|
||||||
state=state)
|
state=state)
|
||||||
tr_scene = np.array([timestep, timestep])
|
tr_scene = np.array([timestep, timestep])
|
||||||
for node in existing_nodes:
|
for node in existing_nodes:
|
||||||
pos_dict[node] = node.get(tr_scene, {'position': ['x', 'y']})
|
input_dict[node] = node.get(tr_scene, state[node.type])
|
||||||
|
|
||||||
return pos_dict
|
return input_dict
|
||||||
|
|
||||||
def get_scene_graph(self,
|
def get_scene_graph(self,
|
||||||
timestep,
|
timestep,
|
||||||
|
@ -160,6 +165,7 @@ class Scene(object):
|
||||||
return clipped_nodes
|
return clipped_nodes
|
||||||
|
|
||||||
tr_scene = np.array([timesteps.min(), timesteps.max()])
|
tr_scene = np.array([timesteps.min(), timesteps.max()])
|
||||||
|
data_header_memo = dict()
|
||||||
for node in all_nodes:
|
for node in all_nodes:
|
||||||
if isinstance(node, MultiNode):
|
if isinstance(node, MultiNode):
|
||||||
copied_node = copy.deepcopy(node.get_node_at_timesteps(tr_scene))
|
copied_node = copy.deepcopy(node.get_node_at_timesteps(tr_scene))
|
||||||
|
@ -168,7 +174,16 @@ class Scene(object):
|
||||||
copied_node = copy.deepcopy(node)
|
copied_node = copy.deepcopy(node)
|
||||||
|
|
||||||
clipped_value = node.get(tr_scene, state[node.type])
|
clipped_value = node.get(tr_scene, state[node.type])
|
||||||
copied_node.overwrite_data(clipped_value)
|
|
||||||
|
if node.type not in data_header_memo:
|
||||||
|
data_header = list()
|
||||||
|
for quantity, values in state[node.type].items():
|
||||||
|
for value in values:
|
||||||
|
data_header.append((quantity, value))
|
||||||
|
|
||||||
|
data_header_memo[node.type] = data_header
|
||||||
|
|
||||||
|
copied_node.overwrite_data(clipped_value, data_header_memo[node.type])
|
||||||
copied_node.first_timestep = tr_scene[0]
|
copied_node.first_timestep = tr_scene[0]
|
||||||
|
|
||||||
clipped_nodes.append(copied_node)
|
clipped_nodes.append(copied_node)
|
||||||
|
|
|
@ -292,7 +292,7 @@ class SceneGraph(object):
|
||||||
other_types = set(node.type for node in other.nodes)
|
other_types = set(node.type for node in other.nodes)
|
||||||
all_node_types = our_types | other_types
|
all_node_types = our_types | other_types
|
||||||
|
|
||||||
new_neighbors = defaultdict(dict)
|
new_neighbors = defaultdict(lambda: defaultdict(set))
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
if node in removed_nodes:
|
if node in removed_nodes:
|
||||||
continue
|
continue
|
||||||
|
@ -306,9 +306,9 @@ class SceneGraph(object):
|
||||||
for node_type in our_types:
|
for node_type in our_types:
|
||||||
neighbors = self.get_neighbors(node, node_type)
|
neighbors = self.get_neighbors(node, node_type)
|
||||||
if len(neighbors) > 0:
|
if len(neighbors) > 0:
|
||||||
new_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)}
|
new_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors)
|
||||||
|
|
||||||
removed_neighbors = defaultdict(dict)
|
removed_neighbors = defaultdict(lambda: defaultdict(set))
|
||||||
for node in other.nodes:
|
for node in other.nodes:
|
||||||
if node in removed_nodes:
|
if node in removed_nodes:
|
||||||
continue
|
continue
|
||||||
|
@ -322,13 +322,13 @@ class SceneGraph(object):
|
||||||
for node_type in other_types:
|
for node_type in other_types:
|
||||||
neighbors = other.get_neighbors(node, node_type)
|
neighbors = other.get_neighbors(node, node_type)
|
||||||
if len(neighbors) > 0:
|
if len(neighbors) > 0:
|
||||||
removed_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)}
|
removed_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors)
|
||||||
|
|
||||||
return new_nodes, removed_nodes, new_neighbors, removed_neighbors
|
return new_nodes, removed_nodes, new_neighbors, removed_neighbors
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from data import NodeTypeEnum
|
from environment import NodeTypeEnum
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# # # # # # # # # # # # # # # # #
|
# # # # # # # # # # # # # # # # #
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .dataset import EnvironmentDataset, NodeTypeDataset
|
from .dataset import EnvironmentDataset, NodeTypeDataset
|
||||||
from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore
|
from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore, get_relative_robot_traj
|
||||||
|
|
|
@ -159,8 +159,9 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state,
|
||||||
else:
|
else:
|
||||||
robot = scene.robot
|
robot = scene.robot
|
||||||
robot_type = robot.type
|
robot_type = robot.type
|
||||||
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
|
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan)
|
||||||
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)
|
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)
|
||||||
|
robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0
|
||||||
|
|
||||||
# Map
|
# Map
|
||||||
map_tuple = None
|
map_tuple = None
|
||||||
|
|
|
@ -62,6 +62,12 @@ class Unicycle(Dynamic):
|
||||||
ph = control_samples.shape[-2]
|
ph = control_samples.shape[-2]
|
||||||
p_0 = self.initial_conditions['pos'].unsqueeze(1)
|
p_0 = self.initial_conditions['pos'].unsqueeze(1)
|
||||||
v_0 = self.initial_conditions['vel'].unsqueeze(1)
|
v_0 = self.initial_conditions['vel'].unsqueeze(1)
|
||||||
|
|
||||||
|
# In case the input is batched because of the robot in online use we repeat this to match the batch size of x.
|
||||||
|
if p_0.size()[0] != x.size()[0]:
|
||||||
|
p_0 = p_0.repeat(x.size()[0], 1, 1)
|
||||||
|
v_0 = v_0.repeat(x.size()[0], 1, 1)
|
||||||
|
|
||||||
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
|
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
|
||||||
|
|
||||||
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))
|
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))
|
||||||
|
@ -193,6 +199,12 @@ class Unicycle(Dynamic):
|
||||||
ph = control_dist_dphi_a.mus.shape[-3]
|
ph = control_dist_dphi_a.mus.shape[-3]
|
||||||
p_0 = self.initial_conditions['pos'].unsqueeze(1)
|
p_0 = self.initial_conditions['pos'].unsqueeze(1)
|
||||||
v_0 = self.initial_conditions['vel'].unsqueeze(1)
|
v_0 = self.initial_conditions['vel'].unsqueeze(1)
|
||||||
|
|
||||||
|
# In case the input is batched because of the robot in online use we repeat this to match the batch size of x.
|
||||||
|
if p_0.size()[0] != x.size()[0]:
|
||||||
|
p_0 = p_0.repeat(x.size()[0], 1, 1)
|
||||||
|
v_0 = v_0.repeat(x.size()[0], 1, 1)
|
||||||
|
|
||||||
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
|
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
|
||||||
|
|
||||||
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))
|
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))
|
||||||
|
|
|
@ -370,6 +370,7 @@ class MultimodalGenerativeCVAE(object):
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
torch.Tensor):
|
torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Encodes input and output tensors for node and robot.
|
Encodes input and output tensors for node and robot.
|
||||||
|
|
|
@ -6,6 +6,7 @@ import numpy as np
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
from model.components import *
|
from model.components import *
|
||||||
from model.model_utils import *
|
from model.model_utils import *
|
||||||
|
from model.dataset import get_relative_robot_traj
|
||||||
import model.dynamics as dynamic_module
|
import model.dynamics as dynamic_module
|
||||||
from model.mgcvae import MultimodalGenerativeCVAE
|
from model.mgcvae import MultimodalGenerativeCVAE
|
||||||
from environment.scene_graph import DirectedEdge
|
from environment.scene_graph import DirectedEdge
|
||||||
|
@ -48,7 +49,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
|
|
||||||
dynamic_class = getattr(dynamic_module, self.hyperparams['dynamic'][self.node_type]['name'])
|
dynamic_class = getattr(dynamic_module, self.hyperparams['dynamic'][self.node_type]['name'])
|
||||||
dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
|
dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
|
||||||
self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device, self.model_registrar, self.x_size)
|
self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device,
|
||||||
|
self.model_registrar, self.x_size, self.node_type)
|
||||||
|
|
||||||
def create_graphical_model(self):
|
def create_graphical_model(self):
|
||||||
"""
|
"""
|
||||||
|
@ -127,11 +129,13 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0:
|
if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0:
|
||||||
del self.node_modules[edge_type + '/edge_encoder']
|
del self.node_modules[edge_type + '/edge_encoder']
|
||||||
|
|
||||||
def obtain_encoded_tensors(self, mode, inputs, inputs_st, inputs_np, robot_present_and_future) -> (torch.Tensor,
|
def obtain_encoded_tensors(self,
|
||||||
torch.Tensor,
|
mode,
|
||||||
torch.Tensor,
|
inputs,
|
||||||
torch.Tensor,
|
inputs_st,
|
||||||
torch.Tensor):
|
inputs_np,
|
||||||
|
robot_present_and_future,
|
||||||
|
maps):
|
||||||
x, x_r_t, y_r = None, None, None
|
x, x_r_t, y_r = None, None, None
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
|
|
||||||
|
@ -139,15 +143,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
our_inputs_st = inputs_st[self.node]
|
our_inputs_st = inputs_st[self.node]
|
||||||
|
|
||||||
initial_dynamics = dict()
|
initial_dynamics = dict()
|
||||||
initial_dynamics['pos'] = our_inputs_st[:, 0:2] # TODO: Generalize
|
initial_dynamics['pos'] = our_inputs[:, 0:2] # TODO: Generalize
|
||||||
initial_dynamics['vel'] = our_inputs_st[:, 2:4] # TODO: Generalize
|
initial_dynamics['vel'] = our_inputs[:, 2:4] # TODO: Generalize
|
||||||
self.dynamic.set_initial_condition(initial_dynamics)
|
self.dynamic.set_initial_condition(initial_dynamics)
|
||||||
|
|
||||||
#########################################
|
#########################################
|
||||||
# Provide basic information to encoders #
|
# Provide basic information to encoders #
|
||||||
#########################################
|
#########################################
|
||||||
if self.hyperparams['incl_robot_node'] and self.robot is not None:
|
if self.hyperparams['incl_robot_node'] and self.robot is not None:
|
||||||
x_r_t, y_r = self.get_relative_robot_traj(our_inputs, robot_present_and_future, self.robot.type)
|
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
|
||||||
|
our_inputs, robot_present_and_future,
|
||||||
|
self.node.type, self.robot.type)
|
||||||
|
x_r_t = robot_present_and_future_st[..., 0, :]
|
||||||
|
y_r = robot_present_and_future_st[..., 1:, :]
|
||||||
|
|
||||||
##################
|
##################
|
||||||
# Encode History #
|
# Encode History #
|
||||||
|
@ -194,6 +202,21 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
self.TD = {'node_history_encoded': node_history_encoded,
|
self.TD = {'node_history_encoded': node_history_encoded,
|
||||||
'total_edge_influence': total_edge_influence}
|
'total_edge_influence': total_edge_influence}
|
||||||
|
|
||||||
|
################
|
||||||
|
# Map Encoding #
|
||||||
|
################
|
||||||
|
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
|
||||||
|
if self.node not in maps:
|
||||||
|
# This means the node was removed (it is only being kept around because of the edge removal filter).
|
||||||
|
me_params = self.hyperparams['map_encoder'][self.node_type]
|
||||||
|
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
|
||||||
|
else:
|
||||||
|
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
|
||||||
|
(mode == ModeKeys.TRAIN))
|
||||||
|
do = self.hyperparams['map_encoder'][self.node_type]['dropout']
|
||||||
|
encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN))
|
||||||
|
self.TD['encoded_map'] = encoded_map
|
||||||
|
|
||||||
######################################
|
######################################
|
||||||
# Concatenate Encoder Outputs into x #
|
# Concatenate Encoder Outputs into x #
|
||||||
######################################
|
######################################
|
||||||
|
@ -207,6 +230,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
node_history_encoded = TD['node_history_encoded']
|
node_history_encoded = TD['node_history_encoded']
|
||||||
if self.hyperparams['edge_encoding']:
|
if self.hyperparams['edge_encoding']:
|
||||||
total_edge_influence = TD['total_edge_influence']
|
total_edge_influence = TD['total_edge_influence']
|
||||||
|
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
|
||||||
|
encoded_map = TD['encoded_map']
|
||||||
|
|
||||||
if (self.hyperparams['incl_robot_node']
|
if (self.hyperparams['incl_robot_node']
|
||||||
and self.robot is not None
|
and self.robot is not None
|
||||||
|
@ -222,6 +247,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1)
|
node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1)
|
||||||
if self.hyperparams['edge_encoding']:
|
if self.hyperparams['edge_encoding']:
|
||||||
total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1)
|
total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1)
|
||||||
|
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
|
||||||
|
encoded_map = TD['encoded_map'].repeat(robot_future_st.size()[0], 1)
|
||||||
|
|
||||||
elif self.hyperparams['incl_robot_node'] and self.robot is not None:
|
elif self.hyperparams['incl_robot_node'] and self.robot is not None:
|
||||||
# Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
|
# Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
|
||||||
|
@ -239,6 +266,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
if self.hyperparams['incl_robot_node'] and self.robot is not None:
|
if self.hyperparams['incl_robot_node'] and self.robot is not None:
|
||||||
x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history]
|
x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history]
|
||||||
|
|
||||||
|
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
|
||||||
|
x_concat_list.append(encoded_map) # [bs/nbs, CNN output size]
|
||||||
|
|
||||||
return torch.cat(x_concat_list, dim=1)
|
return torch.cat(x_concat_list, dim=1)
|
||||||
|
|
||||||
def encode_node_history(self, inputs_st):
|
def encode_node_history(self, inputs_st):
|
||||||
|
@ -258,13 +288,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
edge_states_list = list() # list of [#of neighbors, max_ht, state_dim]
|
edge_states_list = list() # list of [#of neighbors, max_ht, state_dim]
|
||||||
neighbor_states = list()
|
neighbor_states = list()
|
||||||
|
|
||||||
rel_state = inputs[self.node].cpu().numpy()
|
orig_rel_state = inputs[self.node].cpu().numpy()
|
||||||
for node in connected_nodes[0]:
|
for node in connected_nodes[0]:
|
||||||
neighbor_state_np = inputs_np[node]
|
neighbor_state_np = inputs_np[node]
|
||||||
|
|
||||||
# Make State relative to node
|
# Make State relative to node
|
||||||
_, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type)
|
_, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type)
|
||||||
std[0:2] = self.env.attention_radius[edge_type_tuple]
|
std[0:2] = self.env.attention_radius[edge_type_tuple]
|
||||||
|
|
||||||
|
# TODO: This all makes the unsafe assumption that the first n dims
|
||||||
|
# refer to the same quantities even for different agent types!
|
||||||
|
equal_dims = np.min((neighbor_state_np.shape[-1], orig_rel_state.shape[-1]))
|
||||||
|
rel_state = np.zeros_like(neighbor_state_np)
|
||||||
|
rel_state[..., :equal_dims] = orig_rel_state[..., :equal_dims]
|
||||||
neighbor_state_np_st = self.env.standardize(neighbor_state_np,
|
neighbor_state_np_st = self.env.standardize(neighbor_state_np,
|
||||||
self.state[node.type],
|
self.state[node.type],
|
||||||
node_type=node.type,
|
node_type=node.type,
|
||||||
|
@ -333,7 +369,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
else:
|
else:
|
||||||
return outputs[:, 0, :] # [bs, enc_rnn_dim]
|
return outputs[:, 0, :] # [bs, enc_rnn_dim]
|
||||||
|
|
||||||
def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None):
|
def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None, maps=None):
|
||||||
# Always predicting with the online model.
|
# Always predicting with the online model.
|
||||||
mode = ModeKeys.PREDICT
|
mode = ModeKeys.PREDICT
|
||||||
|
|
||||||
|
@ -341,7 +377,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
inputs,
|
inputs,
|
||||||
inputs_st,
|
inputs_st,
|
||||||
inputs_np,
|
inputs_np,
|
||||||
robot_present_and_future)
|
robot_present_and_future,
|
||||||
|
maps)
|
||||||
|
self.n_s_t0 = inputs_st[self.node]
|
||||||
|
|
||||||
self.latent.p_dist = self.p_z_x(mode, self.x)
|
self.latent.p_dist = self.p_z_x(mode, self.x)
|
||||||
|
|
||||||
|
@ -361,27 +399,32 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
if (self.hyperparams['incl_robot_node']
|
if (self.hyperparams['incl_robot_node']
|
||||||
and self.robot is not None
|
and self.robot is not None
|
||||||
and robot_present_and_future is not None):
|
and robot_present_and_future is not None):
|
||||||
x_nr_t, y_r = self.get_relative_robot_traj(
|
our_inputs = torch.tensor(self.node.get(np.array([self.node.last_timestep]),
|
||||||
torch.tensor(self.node.get(np.array([self.node.last_timestep]),
|
|
||||||
self.state[self.node.type],
|
self.state[self.node.type],
|
||||||
padding=0.0),
|
padding=0.0),
|
||||||
dtype=torch.float,
|
dtype=torch.float,
|
||||||
device=self.device),
|
device=self.device)
|
||||||
robot_present_and_future,
|
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
|
||||||
self.robot.type)
|
our_inputs, robot_present_and_future,
|
||||||
|
self.node.type, self.robot.type)
|
||||||
|
x_nr_t = robot_present_and_future_st[..., 0, :]
|
||||||
|
y_r = robot_present_and_future_st[..., 1:, :]
|
||||||
self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r)
|
self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r)
|
||||||
self.latent.p_dist = self.p_z_x(mode, self.x)
|
self.latent.p_dist = self.p_z_x(mode, self.x)
|
||||||
|
|
||||||
|
# Making sure n_s_t0 has the same batch size as x_nr_t
|
||||||
|
self.n_s_t0 = self.n_s_t0[[0]].repeat(x_nr_t.size()[0], 1)
|
||||||
|
|
||||||
z, num_samples, num_components = self.latent.sample_p(num_samples,
|
z, num_samples, num_components = self.latent.sample_p(num_samples,
|
||||||
mode,
|
mode,
|
||||||
most_likely_z=z_mode,
|
most_likely_z=z_mode,
|
||||||
full_dist=full_dist,
|
full_dist=full_dist,
|
||||||
all_z_sep=all_z_sep)
|
all_z_sep=all_z_sep)
|
||||||
|
|
||||||
_, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, z,
|
y_dist, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, self.n_s_t0, z,
|
||||||
prediction_horizon,
|
prediction_horizon,
|
||||||
num_samples,
|
num_samples,
|
||||||
num_components,
|
num_components,
|
||||||
gmm_mode)
|
gmm_mode)
|
||||||
|
|
||||||
return our_sampled_future
|
return y_dist, our_sampled_future
|
||||||
|
|
|
@ -4,7 +4,7 @@ from collections import Counter
|
||||||
from model.trajectron import Trajectron
|
from model.trajectron import Trajectron
|
||||||
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||||||
from model.model_utils import ModeKeys
|
from model.model_utils import ModeKeys
|
||||||
from data import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||||||
|
|
||||||
|
|
||||||
class OnlineTrajectron(Trajectron):
|
class OnlineTrajectron(Trajectron):
|
||||||
|
@ -58,10 +58,11 @@ class OnlineTrajectron(Trajectron):
|
||||||
|
|
||||||
# Fast-forwarding ourselves to the initial timestep, without running any of the underlying models.
|
# Fast-forwarding ourselves to the initial timestep, without running any of the underlying models.
|
||||||
for timestep in range(init_timestep + 1):
|
for timestep in range(init_timestep + 1):
|
||||||
self.incremental_forward(self.env.scenes[0].get_clipped_pos_dict(timestep, self.hyperparams['state']),
|
self.incremental_forward(self.env.scenes[0].get_clipped_input_dict(timestep, self.hyperparams['state']),
|
||||||
run_models=False)
|
maps=None, run_models=False)
|
||||||
|
|
||||||
def incremental_forward(self, new_inputs_dict,
|
def incremental_forward(self, new_inputs_dict,
|
||||||
|
maps,
|
||||||
prediction_horizon=0,
|
prediction_horizon=0,
|
||||||
num_samples=0,
|
num_samples=0,
|
||||||
robot_present_and_future=None,
|
robot_present_and_future=None,
|
||||||
|
@ -80,7 +81,8 @@ class OnlineTrajectron(Trajectron):
|
||||||
for node, new_input in new_inputs_dict.items():
|
for node, new_input in new_inputs_dict.items():
|
||||||
if node not in self.node_data:
|
if node not in self.node_data:
|
||||||
self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY,
|
self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY,
|
||||||
dtype=(float, len(self.state[node.type]['position'])))
|
dtype=(float, sum(len(self.state[node.type][k])
|
||||||
|
for k in self.state[node.type])))
|
||||||
self.node_data[node].append(new_input)
|
self.node_data[node].append(new_input)
|
||||||
|
|
||||||
if node in self.removed_nodes:
|
if node in self.removed_nodes:
|
||||||
|
@ -101,17 +103,10 @@ class OnlineTrajectron(Trajectron):
|
||||||
# that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role
|
# that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role
|
||||||
# (the NaNs get converted to zeros later on).
|
# (the NaNs get converted to zeros later on).
|
||||||
for node in self.removed_nodes:
|
for node in self.removed_nodes:
|
||||||
self.node_data[node].append(np.full((1, 2), np.nan))
|
self.node_data[node].append(np.full((1, self.node_data[node].shape[1]), np.nan))
|
||||||
|
|
||||||
for node in self.node_data:
|
for node in self.node_data:
|
||||||
x = self.node_data[node][:, 0]
|
node.overwrite_data(self.node_data[node], None,
|
||||||
y = self.node_data[node][:, 1]
|
|
||||||
vx = derivative_of(x, self.env.scenes[0].dt)
|
|
||||||
vy = derivative_of(y, self.env.scenes[0].dt)
|
|
||||||
ax = derivative_of(vx, self.env.scenes[0].dt)
|
|
||||||
ay = derivative_of(vy, self.env.scenes[0].dt)
|
|
||||||
new_node_data = np.stack([x, y, vx, vy, ax, ay], axis=-1)
|
|
||||||
node.overwrite_data(new_node_data,
|
|
||||||
forward_in_time_on_next_overwrite=(self.node_data[node].shape[0]
|
forward_in_time_on_next_overwrite=(self.node_data[node].shape[0]
|
||||||
== self.RING_CAPACITY))
|
== self.RING_CAPACITY))
|
||||||
|
|
||||||
|
@ -139,14 +134,15 @@ class OnlineTrajectron(Trajectron):
|
||||||
|
|
||||||
# These next 2 for loops add or remove entire node models.
|
# These next 2 for loops add or remove entire node models.
|
||||||
for node in new_nodes:
|
for node in new_nodes:
|
||||||
if node.is_robot:
|
if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
|
||||||
|
# Only deal with Models for NodeTypes we want to predict
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._add_node_model(node)
|
self._add_node_model(node)
|
||||||
self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors)
|
self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors)
|
||||||
|
|
||||||
for node in removed_nodes:
|
for node in removed_nodes:
|
||||||
if node.is_robot:
|
if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._remove_node_model(node)
|
self._remove_node_model(node)
|
||||||
|
@ -157,7 +153,8 @@ class OnlineTrajectron(Trajectron):
|
||||||
inputs_st = dict()
|
inputs_st = dict()
|
||||||
inputs_np = dict()
|
inputs_np = dict()
|
||||||
|
|
||||||
iter_list = list(self.node_models_dict.keys())
|
iter_list = list(self.node_models_dict.keys()) + [node for node in new_inputs_dict
|
||||||
|
if node.type not in self.pred_state.keys()]
|
||||||
if self.env.scenes[0].robot is not None:
|
if self.env.scenes[0].robot is not None:
|
||||||
iter_list.append(self.env.scenes[0].robot)
|
iter_list.append(self.env.scenes[0].robot)
|
||||||
|
|
||||||
|
@ -191,12 +188,15 @@ class OnlineTrajectron(Trajectron):
|
||||||
robot_present_and_future = robot_present_and_future[np.newaxis, :]
|
robot_present_and_future = robot_present_and_future[np.newaxis, :]
|
||||||
|
|
||||||
assert robot_present_and_future.shape[1] == prediction_horizon + 1
|
assert robot_present_and_future.shape[1] == prediction_horizon + 1
|
||||||
|
robot_present_and_future = torch.tensor(robot_present_and_future,
|
||||||
|
dtype=torch.float, device=self.device)
|
||||||
|
|
||||||
for node in self.node_models_dict:
|
for node in self.node_models_dict:
|
||||||
self.node_models_dict[node].encoder_forward(inputs,
|
self.node_models_dict[node].encoder_forward(inputs,
|
||||||
inputs_st,
|
inputs_st,
|
||||||
inputs_np,
|
inputs_np,
|
||||||
robot_present_and_future)
|
robot_present_and_future,
|
||||||
|
maps)
|
||||||
|
|
||||||
# If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all,
|
# If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all,
|
||||||
# just update the encoder LSTMs.
|
# just update the encoder LSTMs.
|
||||||
|
@ -220,7 +220,7 @@ class OnlineTrajectron(Trajectron):
|
||||||
full_dist=False,
|
full_dist=False,
|
||||||
all_z_sep=False):
|
all_z_sep=False):
|
||||||
model = self.node_models_dict[node]
|
model = self.node_models_dict[node]
|
||||||
predictions = model.decoder_forward(num_predicted_timesteps,
|
prediction_dist, predictions_uns = model.decoder_forward(num_predicted_timesteps,
|
||||||
num_samples,
|
num_samples,
|
||||||
robot_present_and_future=robot_present_and_future,
|
robot_present_and_future=robot_present_and_future,
|
||||||
z_mode=z_mode,
|
z_mode=z_mode,
|
||||||
|
@ -228,13 +228,10 @@ class OnlineTrajectron(Trajectron):
|
||||||
full_dist=full_dist,
|
full_dist=full_dist,
|
||||||
all_z_sep=all_z_sep)
|
all_z_sep=all_z_sep)
|
||||||
|
|
||||||
predictions_uns = self.env.unstandardize(predictions.cpu().detach().numpy(),
|
predictions_np = predictions_uns.cpu().detach().numpy()
|
||||||
self.pred_state[node.type.name],
|
|
||||||
node.type,
|
|
||||||
mean=self.rel_states[node][..., 0:2])
|
|
||||||
|
|
||||||
# Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2)
|
# Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2)
|
||||||
return np.transpose(predictions_uns, (1, 0, 2, 3))
|
return prediction_dist, np.transpose(predictions_np, (1, 0, 2, 3))
|
||||||
|
|
||||||
def sample_model(self, num_predicted_timesteps,
|
def sample_model(self, num_predicted_timesteps,
|
||||||
num_samples,
|
num_samples,
|
||||||
|
@ -262,11 +259,12 @@ class OnlineTrajectron(Trajectron):
|
||||||
# No grad since we're predicting always, as evidenced by the line above.
|
# No grad since we're predicting always, as evidenced by the line above.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
predictions_dict = dict()
|
predictions_dict = dict()
|
||||||
|
prediction_dists = dict()
|
||||||
for node in set(self.nodes) - set(self.removed_nodes.keys()):
|
for node in set(self.nodes) - set(self.removed_nodes.keys()):
|
||||||
if node.is_robot:
|
if node.is_robot:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps,
|
prediction_dists[node], predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps,
|
||||||
num_samples,
|
num_samples,
|
||||||
robot_present_and_future,
|
robot_present_and_future,
|
||||||
z_mode,
|
z_mode,
|
||||||
|
@ -274,11 +272,11 @@ class OnlineTrajectron(Trajectron):
|
||||||
full_dist,
|
full_dist,
|
||||||
all_z_sep)
|
all_z_sep)
|
||||||
|
|
||||||
return predictions_dict
|
return prediction_dists, predictions_dict
|
||||||
|
|
||||||
def forward(self, init_env,
|
def forward(self, init_env,
|
||||||
init_timestep,
|
init_timestep,
|
||||||
pos_dicts, # After the initial environment
|
input_dicts, # After the initial environment
|
||||||
num_predicted_timesteps,
|
num_predicted_timesteps,
|
||||||
num_samples,
|
num_samples,
|
||||||
robot_present_and_future=None,
|
robot_present_and_future=None,
|
||||||
|
@ -294,8 +292,8 @@ class OnlineTrajectron(Trajectron):
|
||||||
self.set_environment(init_env, init_timestep)
|
self.set_environment(init_env, init_timestep)
|
||||||
|
|
||||||
# Looping through and applying updates to the model.
|
# Looping through and applying updates to the model.
|
||||||
for i in range(len(pos_dicts)):
|
for i in range(len(input_dicts)):
|
||||||
self.incremental_forward(pos_dicts[i])
|
self.incremental_forward(input_dicts[i])
|
||||||
|
|
||||||
return self.sample_model(num_predicted_timesteps,
|
return self.sample_model(num_predicted_timesteps,
|
||||||
num_samples,
|
num_samples,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
import pickle
|
import dill
|
||||||
import random
|
import random
|
||||||
import pathlib
|
import pathlib
|
||||||
import evaluation
|
import evaluation
|
||||||
|
@ -11,7 +11,7 @@ import visualization as vis
|
||||||
from argument_parser import args
|
from argument_parser import args
|
||||||
from model.online.online_trajectron import OnlineTrajectron
|
from model.online.online_trajectron import OnlineTrajectron
|
||||||
from model.model_registrar import ModelRegistrar
|
from model.model_registrar import ModelRegistrar
|
||||||
from data import Environment, Scene
|
from environment import Environment, Scene
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
if not torch.cuda.is_available() or args.device == 'cpu':
|
if not torch.cuda.is_available() or args.device == 'cpu':
|
||||||
|
@ -58,8 +58,56 @@ def create_online_env(env, hyperparams, scene_idx, init_timestep):
|
||||||
robot_type=env.robot_type)
|
robot_type=env.robot_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_maps_for_input(input_dict, scene, hyperparams):
|
||||||
|
scene_maps = list()
|
||||||
|
scene_pts = list()
|
||||||
|
heading_angles = list()
|
||||||
|
patch_sizes = list()
|
||||||
|
nodes_with_maps = list()
|
||||||
|
for node in input_dict:
|
||||||
|
if node.type in hyperparams['map_encoder']:
|
||||||
|
x = input_dict[node]
|
||||||
|
me_hyp = hyperparams['map_encoder'][node.type]
|
||||||
|
if 'heading_state_index' in me_hyp:
|
||||||
|
heading_state_index = me_hyp['heading_state_index']
|
||||||
|
# We have to rotate the map in the opposit direction of the agent to match them
|
||||||
|
if type(heading_state_index) is list: # infer from velocity or heading vector
|
||||||
|
heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
|
||||||
|
x[-1, heading_state_index[0]]) * 180 / np.pi
|
||||||
|
else:
|
||||||
|
heading_angle = -x[-1, heading_state_index] * 180 / np.pi
|
||||||
|
else:
|
||||||
|
heading_angle = None
|
||||||
|
|
||||||
|
scene_map = scene.map[node.type]
|
||||||
|
map_point = x[-1, :2]
|
||||||
|
|
||||||
|
patch_size = hyperparams['map_encoder'][node.type]['patch_size']
|
||||||
|
|
||||||
|
scene_maps.append(scene_map)
|
||||||
|
scene_pts.append(map_point)
|
||||||
|
heading_angles.append(heading_angle)
|
||||||
|
patch_sizes.append(patch_size)
|
||||||
|
nodes_with_maps.append(node)
|
||||||
|
|
||||||
|
if heading_angles[0] is None:
|
||||||
|
heading_angles = None
|
||||||
|
else:
|
||||||
|
heading_angles = torch.Tensor(heading_angles)
|
||||||
|
|
||||||
|
maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
|
||||||
|
scene_pts=torch.Tensor(scene_pts),
|
||||||
|
patch_size=patch_sizes[0],
|
||||||
|
rotation=heading_angles)
|
||||||
|
|
||||||
|
maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
|
||||||
|
return maps_dict
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model_dir = os.path.join(args.log_dir, 'models_14_Jan_2020_00_24_21eth_no_rob')
|
# Choose one of the model directory names under the experiment/*/models folders.
|
||||||
|
# Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
|
||||||
|
model_dir = os.path.join(args.log_dir, 'int_ee')
|
||||||
|
|
||||||
# Load hyperparameters from json
|
# Load hyperparameters from json
|
||||||
config_file = os.path.join(model_dir, args.conf)
|
config_file = os.path.join(model_dir, args.conf)
|
||||||
|
@ -78,27 +126,22 @@ def main():
|
||||||
hyperparams['k_eval'] = args.k_eval
|
hyperparams['k_eval'] = args.k_eval
|
||||||
hyperparams['offline_scene_graph'] = args.offline_scene_graph
|
hyperparams['offline_scene_graph'] = args.offline_scene_graph
|
||||||
hyperparams['incl_robot_node'] = args.incl_robot_node
|
hyperparams['incl_robot_node'] = args.incl_robot_node
|
||||||
hyperparams['scene_batch_size'] = args.scene_batch_size
|
|
||||||
hyperparams['node_resample_train'] = args.node_resample_train
|
|
||||||
hyperparams['node_resample_eval'] = args.node_resample_eval
|
|
||||||
hyperparams['scene_resample_train'] = args.scene_resample_train
|
|
||||||
hyperparams['scene_resample_eval'] = args.scene_resample_eval
|
|
||||||
hyperparams['scene_resample_viz'] = args.scene_resample_viz
|
|
||||||
hyperparams['edge_encoding'] = not args.no_edge_encoding
|
hyperparams['edge_encoding'] = not args.no_edge_encoding
|
||||||
|
hyperparams['use_map_encoding'] = args.map_encoding
|
||||||
|
|
||||||
output_save_dir = os.path.join(model_dir, 'pred_figs')
|
output_save_dir = os.path.join(model_dir, 'pred_figs')
|
||||||
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
|
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
|
eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
|
||||||
with open(eval_data_path, 'rb') as f:
|
with open(eval_data_path, 'rb') as f:
|
||||||
eval_env = pickle.load(f, encoding='latin1')
|
eval_env = dill.load(f, encoding='latin1')
|
||||||
|
|
||||||
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
|
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
|
||||||
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
|
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
|
||||||
for scene in eval_env.scenes:
|
for scene in eval_env.scenes:
|
||||||
scene.add_robot_from_nodes(eval_env.robot_type)
|
scene.add_robot_from_nodes(eval_env.robot_type)
|
||||||
|
|
||||||
print('Loaded evaluation data from %s' % (eval_data_path,))
|
print('Loaded data from %s' % (eval_data_path,))
|
||||||
|
|
||||||
# Creating a dummy environment with a single scene that contains information about the world.
|
# Creating a dummy environment with a single scene that contains information about the world.
|
||||||
# When using this code, feel free to use whichever scene index or initial timestep you wish.
|
# When using this code, feel free to use whichever scene index or initial timestep you wish.
|
||||||
|
@ -112,7 +155,7 @@ def main():
|
||||||
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
|
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
|
||||||
|
|
||||||
model_registrar = ModelRegistrar(model_dir, args.eval_device)
|
model_registrar = ModelRegistrar(model_dir, args.eval_device)
|
||||||
model_registrar.load_models(iter_num=1999)
|
model_registrar.load_models(iter_num=12)
|
||||||
|
|
||||||
trajectron = OnlineTrajectron(model_registrar,
|
trajectron = OnlineTrajectron(model_registrar,
|
||||||
hyperparams,
|
hyperparams,
|
||||||
|
@ -126,10 +169,14 @@ def main():
|
||||||
trajectron.set_environment(online_env, init_timestep)
|
trajectron.set_environment(online_env, init_timestep)
|
||||||
|
|
||||||
for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
||||||
pos_dict = eval_scene.get_clipped_pos_dict(timestep, hyperparams['state'])
|
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
||||||
|
|
||||||
|
maps = None
|
||||||
|
if hyperparams['use_map_encoding']:
|
||||||
|
maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
|
||||||
|
|
||||||
robot_present_and_future = None
|
robot_present_and_future = None
|
||||||
if eval_scene.robot is not None:
|
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
|
||||||
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
|
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
|
||||||
timestep + hyperparams['prediction_horizon']]),
|
timestep + hyperparams['prediction_horizon']]),
|
||||||
hyperparams['state'][eval_scene.robot.type],
|
hyperparams['state'][eval_scene.robot.type],
|
||||||
|
@ -138,10 +185,12 @@ def main():
|
||||||
# robot_present_and_future += adjustment
|
# robot_present_and_future += adjustment
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
preds = trajectron.incremental_forward(pos_dict,
|
dists, preds = trajectron.incremental_forward(input_dict,
|
||||||
prediction_horizon=12,
|
maps,
|
||||||
num_samples=25,
|
prediction_horizon=6,
|
||||||
robot_present_and_future=robot_present_and_future)
|
num_samples=1,
|
||||||
|
robot_present_and_future=robot_present_and_future,
|
||||||
|
full_dist=True)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
||||||
1. / (end - start), len(trajectron.nodes),
|
1. / (end - start), len(trajectron.nodes),
|
||||||
|
@ -152,23 +201,16 @@ def main():
|
||||||
if node in preds:
|
if node in preds:
|
||||||
detailed_preds_dict[node] = preds[node]
|
detailed_preds_dict[node] = preds[node]
|
||||||
|
|
||||||
batch_stats = evaluation.compute_batch_statistics({timestep: detailed_preds_dict},
|
|
||||||
eval_scene.dt,
|
|
||||||
max_hl=hyperparams['maximum_history_length'],
|
|
||||||
ph=hyperparams['prediction_horizon'],
|
|
||||||
node_type_enum=online_env.NodeType,
|
|
||||||
prune_ph_to_future=True)
|
|
||||||
|
|
||||||
evaluation.print_batch_errors([batch_stats], 'eval', timestep)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
|
vis.visualize_distribution(ax,
|
||||||
|
dists)
|
||||||
vis.visualize_prediction(ax,
|
vis.visualize_prediction(ax,
|
||||||
{timestep: preds},
|
{timestep: preds},
|
||||||
eval_scene.dt,
|
eval_scene.dt,
|
||||||
hyperparams['maximum_history_length'],
|
hyperparams['maximum_history_length'],
|
||||||
hyperparams['prediction_horizon'])
|
hyperparams['prediction_horizon'])
|
||||||
|
|
||||||
if eval_scene.robot is not None:
|
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
|
||||||
robot_for_plotting = eval_scene.robot.get(np.array([timestep,
|
robot_for_plotting = eval_scene.robot.get(np.array([timestep,
|
||||||
timestep + hyperparams['prediction_horizon']]),
|
timestep + hyperparams['prediction_horizon']]),
|
||||||
hyperparams['state'][eval_scene.robot.type])
|
hyperparams['state'][eval_scene.robot.type])
|
||||||
|
|
|
@ -132,6 +132,9 @@ def main():
|
||||||
return_robot=not args.incl_robot_node)
|
return_robot=not args.incl_robot_node)
|
||||||
train_data_loader = dict()
|
train_data_loader = dict()
|
||||||
for node_type_data_set in train_dataset:
|
for node_type_data_set in train_dataset:
|
||||||
|
if len(node_type_data_set) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
pin_memory=False if args.device is 'cpu' else True,
|
pin_memory=False if args.device is 'cpu' else True,
|
||||||
|
@ -172,6 +175,9 @@ def main():
|
||||||
return_robot=not args.incl_robot_node)
|
return_robot=not args.incl_robot_node)
|
||||||
eval_data_loader = dict()
|
eval_data_loader = dict()
|
||||||
for node_type_data_set in eval_dataset:
|
for node_type_data_set in eval_dataset:
|
||||||
|
if len(node_type_data_set) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
|
||||||
collate_fn=collate,
|
collate_fn=collate,
|
||||||
pin_memory=False if args.eval_device is 'cpu' else True,
|
pin_memory=False if args.eval_device is 'cpu' else True,
|
||||||
|
@ -285,6 +291,7 @@ def main():
|
||||||
predictions = trajectron.predict(scene,
|
predictions = trajectron.predict(scene,
|
||||||
timestep,
|
timestep,
|
||||||
ph,
|
ph,
|
||||||
|
min_future_timesteps=ph,
|
||||||
z_mode=True,
|
z_mode=True,
|
||||||
gmm_mode=True,
|
gmm_mode=True,
|
||||||
all_z_sep=False,
|
all_z_sep=False,
|
||||||
|
|
|
@ -1,2 +1,2 @@
|
||||||
from .visualization import visualize_prediction
|
from .visualization import visualize_prediction, visualize_distribution
|
||||||
from .visualization_utils import plot_boxplots
|
from .visualization_utils import plot_boxplots
|
|
@ -1,5 +1,7 @@
|
||||||
from utils import prediction_output_to_trajectories
|
from utils import prediction_output_to_trajectories
|
||||||
|
from scipy import linalg
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib.patches as patches
|
||||||
import matplotlib.patheffects as pe
|
import matplotlib.patheffects as pe
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
|
@ -87,3 +89,42 @@ def visualize_prediction(ax,
|
||||||
if map is not None:
|
if map is not None:
|
||||||
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
|
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
|
||||||
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
|
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_distribution(ax,
|
||||||
|
prediction_distribution_dict,
|
||||||
|
map=None,
|
||||||
|
pi_threshold=0.05,
|
||||||
|
**kwargs):
|
||||||
|
if map is not None:
|
||||||
|
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
|
||||||
|
|
||||||
|
for node, pred_dist in prediction_distribution_dict.items():
|
||||||
|
if pred_dist.mus.shape[:2] != (1, 1):
|
||||||
|
return
|
||||||
|
|
||||||
|
means = pred_dist.mus.squeeze().cpu().numpy()
|
||||||
|
covs = pred_dist.get_covariance_matrix().squeeze().cpu().numpy()
|
||||||
|
pis = pred_dist.pis_cat_dist.probs.squeeze().cpu().numpy()
|
||||||
|
|
||||||
|
for timestep in range(means.shape[0]):
|
||||||
|
for z_val in range(means.shape[1]):
|
||||||
|
mean = means[timestep, z_val]
|
||||||
|
covar = covs[timestep, z_val]
|
||||||
|
pi = pis[timestep, z_val]
|
||||||
|
|
||||||
|
if pi < pi_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
v, w = linalg.eigh(covar)
|
||||||
|
v = 2. * np.sqrt(2.) * np.sqrt(v)
|
||||||
|
u = w[0] / linalg.norm(w[0])
|
||||||
|
|
||||||
|
# Plot an ellipse to show the Gaussian component
|
||||||
|
angle = np.arctan(u[1] / u[0])
|
||||||
|
angle = 180. * angle / np.pi # convert to degrees
|
||||||
|
ell = patches.Ellipse(mean, v[0], v[1], 180. + angle, color='blue' if node.type.name == 'VEHICLE' else 'orange')
|
||||||
|
ell.set_edgecolor(None)
|
||||||
|
ell.set_clip_box(ax.bbox)
|
||||||
|
ell.set_alpha(pi/10)
|
||||||
|
ax.add_artist(ell)
|
||||||
|
|
Loading…
Reference in a new issue