Adding ability to run on streaming data.

This commit is contained in:
BorisIvanovic 2020-12-09 22:42:06 -05:00
parent b99e07dcab
commit 174198bd4f
14 changed files with 305 additions and 125 deletions

View File

@ -130,6 +130,22 @@ If you instead wanted to evaluate a trained model's performance on forecasting p
These scripts will produce csv files in the `results` directory which can then be analyzed in the `NuScenes Quantitative.ipynb` notebook.
## Online Execution ##
As of December 2020, this repository includes an "online" running capability. In addition to the regular batched mode for training and testing, Trajectron++ can now be executed online on streaming data!
The `trajectron/test_online.py` script shows how to use it, and can be run as follows (depending on the desired model).
| Model | Command | File Changes |
|-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
| Base | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'vel_ee'` |
| +Dynamics Integration | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'int_ee'` |
| +Dynamics Integration, Maps | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding | Line 110: `'int_ee_me'` |
| +Dynamics Integration, Maps, Robot Future | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding --incl_robot_node | Line 110: `'robot'` |
Further, lines 145-151 can be changed to choose different scenes and starting timesteps.
During running, each prediction will be iteratively visualized and saved in a `pred_figs/` folder within the specified model folder. For example, if the script loads the `int_ee` version of Trajectron++ then generated figures will be saved to `experiments/nuScenes/models/int_ee/pred_figs/`.
## Datasets ##
### ETH and UCY Pedestrian Datasets ###

View File

@ -46,7 +46,7 @@ class Node(object):
def __repr__(self):
return '/'.join([self.type.name, self.id])
def overwrite_data(self, data, forward_in_time_on_next_overwrite=False):
def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False):
"""
This function hard overwrites the data matrix. When using it you have to make sure that the columns
in the new data matrix correspond to the old structure. As well as setting first_timestep.
@ -55,7 +55,11 @@ class Node(object):
:param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased.
:return: None
"""
self.data.data = data
if header is None:
self.data.data = data
else:
self.data = DoubleHeaderNumpyArray(data, header)
self._last_timestep = None
if self.forward_in_time_on_next_override:
self.first_timestep += 1

View File

@ -25,23 +25,28 @@ class Scene(object):
self.non_aug_scene = non_aug_scene
def add_robot_from_nodes(self, robot_type):
nodes_list = [node for node in self.nodes if node.type == robot_type]
non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3)
self.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True)
scenes = [self]
if hasattr(self, 'augmented'):
scenes += self.augmented
for node in non_overlapping_nodes:
self.nodes.remove(node)
self.nodes.append(self.robot)
for scn in scenes:
nodes_list = [node for node in scn.nodes if node.type == robot_type]
non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3)
scn.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True)
def get_clipped_pos_dict(self, timestep, state):
pos_dict = dict()
for node in non_overlapping_nodes:
scn.nodes.remove(node)
scn.nodes.append(scn.robot)
def get_clipped_input_dict(self, timestep, state):
input_dict = dict()
existing_nodes = self.get_nodes_clipped_at_time(timesteps=np.array([timestep]),
state=state)
tr_scene = np.array([timestep, timestep])
for node in existing_nodes:
pos_dict[node] = node.get(tr_scene, {'position': ['x', 'y']})
input_dict[node] = node.get(tr_scene, state[node.type])
return pos_dict
return input_dict
def get_scene_graph(self,
timestep,
@ -160,6 +165,7 @@ class Scene(object):
return clipped_nodes
tr_scene = np.array([timesteps.min(), timesteps.max()])
data_header_memo = dict()
for node in all_nodes:
if isinstance(node, MultiNode):
copied_node = copy.deepcopy(node.get_node_at_timesteps(tr_scene))
@ -168,7 +174,16 @@ class Scene(object):
copied_node = copy.deepcopy(node)
clipped_value = node.get(tr_scene, state[node.type])
copied_node.overwrite_data(clipped_value)
if node.type not in data_header_memo:
data_header = list()
for quantity, values in state[node.type].items():
for value in values:
data_header.append((quantity, value))
data_header_memo[node.type] = data_header
copied_node.overwrite_data(clipped_value, data_header_memo[node.type])
copied_node.first_timestep = tr_scene[0]
clipped_nodes.append(copied_node)

View File

@ -292,7 +292,7 @@ class SceneGraph(object):
other_types = set(node.type for node in other.nodes)
all_node_types = our_types | other_types
new_neighbors = defaultdict(dict)
new_neighbors = defaultdict(lambda: defaultdict(set))
for node in self.nodes:
if node in removed_nodes:
continue
@ -306,9 +306,9 @@ class SceneGraph(object):
for node_type in our_types:
neighbors = self.get_neighbors(node, node_type)
if len(neighbors) > 0:
new_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)}
new_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors)
removed_neighbors = defaultdict(dict)
removed_neighbors = defaultdict(lambda: defaultdict(set))
for node in other.nodes:
if node in removed_nodes:
continue
@ -322,13 +322,13 @@ class SceneGraph(object):
for node_type in other_types:
neighbors = other.get_neighbors(node, node_type)
if len(neighbors) > 0:
removed_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)}
removed_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors)
return new_nodes, removed_nodes, new_neighbors, removed_neighbors
if __name__ == '__main__':
from data import NodeTypeEnum
from environment import NodeTypeEnum
import time
# # # # # # # # # # # # # # # # #

View File

@ -1,2 +1,2 @@
from .dataset import EnvironmentDataset, NodeTypeDataset
from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore
from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore, get_relative_robot_traj

View File

@ -159,8 +159,9 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state,
else:
robot = scene.robot
robot_type = robot.type
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan)
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)
robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0
# Map
map_tuple = None

View File

@ -62,6 +62,12 @@ class Unicycle(Dynamic):
ph = control_samples.shape[-2]
p_0 = self.initial_conditions['pos'].unsqueeze(1)
v_0 = self.initial_conditions['vel'].unsqueeze(1)
# In case the input is batched because of the robot in online use we repeat this to match the batch size of x.
if p_0.size()[0] != x.size()[0]:
p_0 = p_0.repeat(x.size()[0], 1, 1)
v_0 = v_0.repeat(x.size()[0], 1, 1)
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))
@ -193,6 +199,12 @@ class Unicycle(Dynamic):
ph = control_dist_dphi_a.mus.shape[-3]
p_0 = self.initial_conditions['pos'].unsqueeze(1)
v_0 = self.initial_conditions['vel'].unsqueeze(1)
# In case the input is batched because of the robot in online use we repeat this to match the batch size of x.
if p_0.size()[0] != x.size()[0]:
p_0 = p_0.repeat(x.size()[0], 1, 1)
v_0 = v_0.repeat(x.size()[0], 1, 1)
phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))

View File

@ -367,10 +367,11 @@ class MultimodalGenerativeCVAE(object):
neighbors_edge_value,
robot,
map) -> (torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor):
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor):
"""
Encodes input and output tensors for node and robot.
@ -471,7 +472,7 @@ class MultimodalGenerativeCVAE(object):
if self.log_writer and (self.curr_iter + 1) % 500 == 0:
map_clone = map.clone()
map_patch = self.hyperparams['map_encoder'][self.node_type]['patch_size']
map_clone[:, :, map_patch[1]-5:map_patch[1]+5, map_patch[0]-5:map_patch[0]+5] = 1.
map_clone[:, :, map_patch[1] - 5:map_patch[1] + 5, map_patch[0] - 5:map_patch[0] + 5] = 1.
self.log_writer.add_images(f"{self.node_type}/cropped_maps", map_clone,
self.curr_iter, dataformats='NCWH')
@ -812,10 +813,10 @@ class MultimodalGenerativeCVAE(object):
state = initial_state
if self.hyperparams['incl_robot_node']:
input_ = torch.cat([zx,
a_0.repeat(num_samples*num_components, 1),
x_nr_t.repeat(num_samples*num_components, 1)], dim=1)
a_0.repeat(num_samples * num_components, 1),
x_nr_t.repeat(num_samples * num_components, 1)], dim=1)
else:
input_ = torch.cat([zx, a_0.repeat(num_samples*num_components, 1)], dim=1)
input_ = torch.cat([zx, a_0.repeat(num_samples * num_components, 1)], dim=1)
for j in range(ph):
h_state = cell(input_, state)
@ -989,7 +990,7 @@ class MultimodalGenerativeCVAE(object):
z, kl = self.encoder(mode, x, y_e)
log_p_y_xz = self.decoder(mode, x, x_nr_t, y, y_r, n_s_t0, z,
labels, # Loss is calculated on unstandardized label
labels, # Loss is calculated on unstandardized label
prediction_horizon,
self.hyperparams['k'])

View File

@ -6,6 +6,7 @@ import numpy as np
from collections import defaultdict, Counter
from model.components import *
from model.model_utils import *
from model.dataset import get_relative_robot_traj
import model.dynamics as dynamic_module
from model.mgcvae import MultimodalGenerativeCVAE
from environment.scene_graph import DirectedEdge
@ -48,7 +49,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
dynamic_class = getattr(dynamic_module, self.hyperparams['dynamic'][self.node_type]['name'])
dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device, self.model_registrar, self.x_size)
self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device,
self.model_registrar, self.x_size, self.node_type)
def create_graphical_model(self):
"""
@ -127,11 +129,13 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0:
del self.node_modules[edge_type + '/edge_encoder']
def obtain_encoded_tensors(self, mode, inputs, inputs_st, inputs_np, robot_present_and_future) -> (torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor):
def obtain_encoded_tensors(self,
mode,
inputs,
inputs_st,
inputs_np,
robot_present_and_future,
maps):
x, x_r_t, y_r = None, None, None
batch_size = 1
@ -139,15 +143,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
our_inputs_st = inputs_st[self.node]
initial_dynamics = dict()
initial_dynamics['pos'] = our_inputs_st[:, 0:2] # TODO: Generalize
initial_dynamics['vel'] = our_inputs_st[:, 2:4] # TODO: Generalize
initial_dynamics['pos'] = our_inputs[:, 0:2] # TODO: Generalize
initial_dynamics['vel'] = our_inputs[:, 2:4] # TODO: Generalize
self.dynamic.set_initial_condition(initial_dynamics)
#########################################
# Provide basic information to encoders #
#########################################
if self.hyperparams['incl_robot_node'] and self.robot is not None:
x_r_t, y_r = self.get_relative_robot_traj(our_inputs, robot_present_and_future, self.robot.type)
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
our_inputs, robot_present_and_future,
self.node.type, self.robot.type)
x_r_t = robot_present_and_future_st[..., 0, :]
y_r = robot_present_and_future_st[..., 1:, :]
##################
# Encode History #
@ -194,6 +202,21 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
self.TD = {'node_history_encoded': node_history_encoded,
'total_edge_influence': total_edge_influence}
################
# Map Encoding #
################
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
if self.node not in maps:
# This means the node was removed (it is only being kept around because of the edge removal filter).
me_params = self.hyperparams['map_encoder'][self.node_type]
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
else:
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
(mode == ModeKeys.TRAIN))
do = self.hyperparams['map_encoder'][self.node_type]['dropout']
encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN))
self.TD['encoded_map'] = encoded_map
######################################
# Concatenate Encoder Outputs into x #
######################################
@ -207,6 +230,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
node_history_encoded = TD['node_history_encoded']
if self.hyperparams['edge_encoding']:
total_edge_influence = TD['total_edge_influence']
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
encoded_map = TD['encoded_map']
if (self.hyperparams['incl_robot_node']
and self.robot is not None
@ -222,6 +247,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1)
if self.hyperparams['edge_encoding']:
total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1)
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
encoded_map = TD['encoded_map'].repeat(robot_future_st.size()[0], 1)
elif self.hyperparams['incl_robot_node'] and self.robot is not None:
# Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
@ -239,6 +266,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
if self.hyperparams['incl_robot_node'] and self.robot is not None:
x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history]
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
x_concat_list.append(encoded_map) # [bs/nbs, CNN output size]
return torch.cat(x_concat_list, dim=1)
def encode_node_history(self, inputs_st):
@ -258,13 +288,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
edge_states_list = list() # list of [#of neighbors, max_ht, state_dim]
neighbor_states = list()
rel_state = inputs[self.node].cpu().numpy()
orig_rel_state = inputs[self.node].cpu().numpy()
for node in connected_nodes[0]:
neighbor_state_np = inputs_np[node]
# Make State relative to node
_, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type)
std[0:2] = self.env.attention_radius[edge_type_tuple]
# TODO: This all makes the unsafe assumption that the first n dims
# refer to the same quantities even for different agent types!
equal_dims = np.min((neighbor_state_np.shape[-1], orig_rel_state.shape[-1]))
rel_state = np.zeros_like(neighbor_state_np)
rel_state[..., :equal_dims] = orig_rel_state[..., :equal_dims]
neighbor_state_np_st = self.env.standardize(neighbor_state_np,
self.state[node.type],
node_type=node.type,
@ -333,7 +369,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
else:
return outputs[:, 0, :] # [bs, enc_rnn_dim]
def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None):
def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None, maps=None):
# Always predicting with the online model.
mode = ModeKeys.PREDICT
@ -341,7 +377,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
inputs,
inputs_st,
inputs_np,
robot_present_and_future)
robot_present_and_future,
maps)
self.n_s_t0 = inputs_st[self.node]
self.latent.p_dist = self.p_z_x(mode, self.x)
@ -361,27 +399,32 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
if (self.hyperparams['incl_robot_node']
and self.robot is not None
and robot_present_and_future is not None):
x_nr_t, y_r = self.get_relative_robot_traj(
torch.tensor(self.node.get(np.array([self.node.last_timestep]),
self.state[self.node.type],
padding=0.0),
dtype=torch.float,
device=self.device),
robot_present_and_future,
self.robot.type)
our_inputs = torch.tensor(self.node.get(np.array([self.node.last_timestep]),
self.state[self.node.type],
padding=0.0),
dtype=torch.float,
device=self.device)
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
our_inputs, robot_present_and_future,
self.node.type, self.robot.type)
x_nr_t = robot_present_and_future_st[..., 0, :]
y_r = robot_present_and_future_st[..., 1:, :]
self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r)
self.latent.p_dist = self.p_z_x(mode, self.x)
# Making sure n_s_t0 has the same batch size as x_nr_t
self.n_s_t0 = self.n_s_t0[[0]].repeat(x_nr_t.size()[0], 1)
z, num_samples, num_components = self.latent.sample_p(num_samples,
mode,
most_likely_z=z_mode,
full_dist=full_dist,
all_z_sep=all_z_sep)
_, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, z,
prediction_horizon,
num_samples,
num_components,
gmm_mode)
y_dist, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, self.n_s_t0, z,
prediction_horizon,
num_samples,
num_components,
gmm_mode)
return our_sampled_future
return y_dist, our_sampled_future

View File

@ -4,7 +4,7 @@ from collections import Counter
from model.trajectron import Trajectron
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
from model.model_utils import ModeKeys
from data import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
class OnlineTrajectron(Trajectron):
@ -58,10 +58,11 @@ class OnlineTrajectron(Trajectron):
# Fast-forwarding ourselves to the initial timestep, without running any of the underlying models.
for timestep in range(init_timestep + 1):
self.incremental_forward(self.env.scenes[0].get_clipped_pos_dict(timestep, self.hyperparams['state']),
run_models=False)
self.incremental_forward(self.env.scenes[0].get_clipped_input_dict(timestep, self.hyperparams['state']),
maps=None, run_models=False)
def incremental_forward(self, new_inputs_dict,
maps,
prediction_horizon=0,
num_samples=0,
robot_present_and_future=None,
@ -80,7 +81,8 @@ class OnlineTrajectron(Trajectron):
for node, new_input in new_inputs_dict.items():
if node not in self.node_data:
self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY,
dtype=(float, len(self.state[node.type]['position'])))
dtype=(float, sum(len(self.state[node.type][k])
for k in self.state[node.type])))
self.node_data[node].append(new_input)
if node in self.removed_nodes:
@ -101,17 +103,10 @@ class OnlineTrajectron(Trajectron):
# that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role
# (the NaNs get converted to zeros later on).
for node in self.removed_nodes:
self.node_data[node].append(np.full((1, 2), np.nan))
self.node_data[node].append(np.full((1, self.node_data[node].shape[1]), np.nan))
for node in self.node_data:
x = self.node_data[node][:, 0]
y = self.node_data[node][:, 1]
vx = derivative_of(x, self.env.scenes[0].dt)
vy = derivative_of(y, self.env.scenes[0].dt)
ax = derivative_of(vx, self.env.scenes[0].dt)
ay = derivative_of(vy, self.env.scenes[0].dt)
new_node_data = np.stack([x, y, vx, vy, ax, ay], axis=-1)
node.overwrite_data(new_node_data,
node.overwrite_data(self.node_data[node], None,
forward_in_time_on_next_overwrite=(self.node_data[node].shape[0]
== self.RING_CAPACITY))
@ -139,14 +134,15 @@ class OnlineTrajectron(Trajectron):
# These next 2 for loops add or remove entire node models.
for node in new_nodes:
if node.is_robot:
if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
# Only deal with Models for NodeTypes we want to predict
continue
self._add_node_model(node)
self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors)
for node in removed_nodes:
if node.is_robot:
if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
continue
self._remove_node_model(node)
@ -157,7 +153,8 @@ class OnlineTrajectron(Trajectron):
inputs_st = dict()
inputs_np = dict()
iter_list = list(self.node_models_dict.keys())
iter_list = list(self.node_models_dict.keys()) + [node for node in new_inputs_dict
if node.type not in self.pred_state.keys()]
if self.env.scenes[0].robot is not None:
iter_list.append(self.env.scenes[0].robot)
@ -191,12 +188,15 @@ class OnlineTrajectron(Trajectron):
robot_present_and_future = robot_present_and_future[np.newaxis, :]
assert robot_present_and_future.shape[1] == prediction_horizon + 1
robot_present_and_future = torch.tensor(robot_present_and_future,
dtype=torch.float, device=self.device)
for node in self.node_models_dict:
self.node_models_dict[node].encoder_forward(inputs,
inputs_st,
inputs_np,
robot_present_and_future)
robot_present_and_future,
maps)
# If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all,
# just update the encoder LSTMs.
@ -220,21 +220,18 @@ class OnlineTrajectron(Trajectron):
full_dist=False,
all_z_sep=False):
model = self.node_models_dict[node]
predictions = model.decoder_forward(num_predicted_timesteps,
num_samples,
robot_present_and_future=robot_present_and_future,
z_mode=z_mode,
gmm_mode=gmm_mode,
full_dist=full_dist,
all_z_sep=all_z_sep)
prediction_dist, predictions_uns = model.decoder_forward(num_predicted_timesteps,
num_samples,
robot_present_and_future=robot_present_and_future,
z_mode=z_mode,
gmm_mode=gmm_mode,
full_dist=full_dist,
all_z_sep=all_z_sep)
predictions_uns = self.env.unstandardize(predictions.cpu().detach().numpy(),
self.pred_state[node.type.name],
node.type,
mean=self.rel_states[node][..., 0:2])
predictions_np = predictions_uns.cpu().detach().numpy()
# Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2)
return np.transpose(predictions_uns, (1, 0, 2, 3))
return prediction_dist, np.transpose(predictions_np, (1, 0, 2, 3))
def sample_model(self, num_predicted_timesteps,
num_samples,
@ -262,23 +259,24 @@ class OnlineTrajectron(Trajectron):
# No grad since we're predicting always, as evidenced by the line above.
with torch.no_grad():
predictions_dict = dict()
prediction_dists = dict()
for node in set(self.nodes) - set(self.removed_nodes.keys()):
if node.is_robot:
continue
predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps,
num_samples,
robot_present_and_future,
z_mode,
gmm_mode,
full_dist,
all_z_sep)
prediction_dists[node], predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps,
num_samples,
robot_present_and_future,
z_mode,
gmm_mode,
full_dist,
all_z_sep)
return predictions_dict
return prediction_dists, predictions_dict
def forward(self, init_env,
init_timestep,
pos_dicts, # After the initial environment
input_dicts, # After the initial environment
num_predicted_timesteps,
num_samples,
robot_present_and_future=None,
@ -294,8 +292,8 @@ class OnlineTrajectron(Trajectron):
self.set_environment(init_env, init_timestep)
# Looping through and applying updates to the model.
for i in range(len(pos_dicts)):
self.incremental_forward(pos_dicts[i])
for i in range(len(input_dicts)):
self.incremental_forward(input_dicts[i])
return self.sample_model(num_predicted_timesteps,
num_samples,

View File

@ -2,7 +2,7 @@ import os
import time
import json
import torch
import pickle
import dill
import random
import pathlib
import evaluation
@ -11,7 +11,7 @@ import visualization as vis
from argument_parser import args
from model.online.online_trajectron import OnlineTrajectron
from model.model_registrar import ModelRegistrar
from data import Environment, Scene
from environment import Environment, Scene
import matplotlib.pyplot as plt
if not torch.cuda.is_available() or args.device == 'cpu':
@ -58,8 +58,56 @@ def create_online_env(env, hyperparams, scene_idx, init_timestep):
robot_type=env.robot_type)
def get_maps_for_input(input_dict, scene, hyperparams):
scene_maps = list()
scene_pts = list()
heading_angles = list()
patch_sizes = list()
nodes_with_maps = list()
for node in input_dict:
if node.type in hyperparams['map_encoder']:
x = input_dict[node]
me_hyp = hyperparams['map_encoder'][node.type]
if 'heading_state_index' in me_hyp:
heading_state_index = me_hyp['heading_state_index']
# We have to rotate the map in the opposit direction of the agent to match them
if type(heading_state_index) is list: # infer from velocity or heading vector
heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
x[-1, heading_state_index[0]]) * 180 / np.pi
else:
heading_angle = -x[-1, heading_state_index] * 180 / np.pi
else:
heading_angle = None
scene_map = scene.map[node.type]
map_point = x[-1, :2]
patch_size = hyperparams['map_encoder'][node.type]['patch_size']
scene_maps.append(scene_map)
scene_pts.append(map_point)
heading_angles.append(heading_angle)
patch_sizes.append(patch_size)
nodes_with_maps.append(node)
if heading_angles[0] is None:
heading_angles = None
else:
heading_angles = torch.Tensor(heading_angles)
maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
scene_pts=torch.Tensor(scene_pts),
patch_size=patch_sizes[0],
rotation=heading_angles)
maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
return maps_dict
def main():
model_dir = os.path.join(args.log_dir, 'models_14_Jan_2020_00_24_21eth_no_rob')
# Choose one of the model directory names under the experiment/*/models folders.
# Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
model_dir = os.path.join(args.log_dir, 'int_ee')
# Load hyperparameters from json
config_file = os.path.join(model_dir, args.conf)
@ -78,27 +126,22 @@ def main():
hyperparams['k_eval'] = args.k_eval
hyperparams['offline_scene_graph'] = args.offline_scene_graph
hyperparams['incl_robot_node'] = args.incl_robot_node
hyperparams['scene_batch_size'] = args.scene_batch_size
hyperparams['node_resample_train'] = args.node_resample_train
hyperparams['node_resample_eval'] = args.node_resample_eval
hyperparams['scene_resample_train'] = args.scene_resample_train
hyperparams['scene_resample_eval'] = args.scene_resample_eval
hyperparams['scene_resample_viz'] = args.scene_resample_viz
hyperparams['edge_encoding'] = not args.no_edge_encoding
hyperparams['use_map_encoding'] = args.map_encoding
output_save_dir = os.path.join(model_dir, 'pred_figs')
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
eval_data_path = os.path.join(args.data_dir, args.eval_data_dict)
with open(eval_data_path, 'rb') as f:
eval_env = pickle.load(f, encoding='latin1')
eval_env = dill.load(f, encoding='latin1')
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
for scene in eval_env.scenes:
scene.add_robot_from_nodes(eval_env.robot_type)
print('Loaded evaluation data from %s' % (eval_data_path,))
print('Loaded data from %s' % (eval_data_path,))
# Creating a dummy environment with a single scene that contains information about the world.
# When using this code, feel free to use whichever scene index or initial timestep you wish.
@ -112,7 +155,7 @@ def main():
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
model_registrar = ModelRegistrar(model_dir, args.eval_device)
model_registrar.load_models(iter_num=1999)
model_registrar.load_models(iter_num=12)
trajectron = OnlineTrajectron(model_registrar,
hyperparams,
@ -126,10 +169,14 @@ def main():
trajectron.set_environment(online_env, init_timestep)
for timestep in range(init_timestep + 1, eval_scene.timesteps):
pos_dict = eval_scene.get_clipped_pos_dict(timestep, hyperparams['state'])
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
maps = None
if hyperparams['use_map_encoding']:
maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
robot_present_and_future = None
if eval_scene.robot is not None:
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
timestep + hyperparams['prediction_horizon']]),
hyperparams['state'][eval_scene.robot.type],
@ -138,10 +185,12 @@ def main():
# robot_present_and_future += adjustment
start = time.time()
preds = trajectron.incremental_forward(pos_dict,
prediction_horizon=12,
num_samples=25,
robot_present_and_future=robot_present_and_future)
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=6,
num_samples=1,
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
1. / (end - start), len(trajectron.nodes),
@ -152,23 +201,16 @@ def main():
if node in preds:
detailed_preds_dict[node] = preds[node]
batch_stats = evaluation.compute_batch_statistics({timestep: detailed_preds_dict},
eval_scene.dt,
max_hl=hyperparams['maximum_history_length'],
ph=hyperparams['prediction_horizon'],
node_type_enum=online_env.NodeType,
prune_ph_to_future=True)
evaluation.print_batch_errors([batch_stats], 'eval', timestep)
fig, ax = plt.subplots()
vis.visualize_distribution(ax,
dists)
vis.visualize_prediction(ax,
{timestep: preds},
eval_scene.dt,
hyperparams['maximum_history_length'],
hyperparams['prediction_horizon'])
if eval_scene.robot is not None:
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
robot_for_plotting = eval_scene.robot.get(np.array([timestep,
timestep + hyperparams['prediction_horizon']]),
hyperparams['state'][eval_scene.robot.type])

View File

@ -132,6 +132,9 @@ def main():
return_robot=not args.incl_robot_node)
train_data_loader = dict()
for node_type_data_set in train_dataset:
if len(node_type_data_set) == 0:
continue
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
collate_fn=collate,
pin_memory=False if args.device is 'cpu' else True,
@ -172,6 +175,9 @@ def main():
return_robot=not args.incl_robot_node)
eval_data_loader = dict()
for node_type_data_set in eval_dataset:
if len(node_type_data_set) == 0:
continue
node_type_dataloader = utils.data.DataLoader(node_type_data_set,
collate_fn=collate,
pin_memory=False if args.eval_device is 'cpu' else True,
@ -285,6 +291,7 @@ def main():
predictions = trajectron.predict(scene,
timestep,
ph,
min_future_timesteps=ph,
z_mode=True,
gmm_mode=True,
all_z_sep=False,

View File

@ -1,2 +1,2 @@
from .visualization import visualize_prediction
from .visualization import visualize_prediction, visualize_distribution
from .visualization_utils import plot_boxplots

View File

@ -1,5 +1,7 @@
from utils import prediction_output_to_trajectories
from scipy import linalg
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.patheffects as pe
import numpy as np
import seaborn as sns
@ -86,4 +88,43 @@ def visualize_prediction(ax,
if map is not None:
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs)
def visualize_distribution(ax,
prediction_distribution_dict,
map=None,
pi_threshold=0.05,
**kwargs):
if map is not None:
ax.imshow(map.as_image(), origin='lower', alpha=0.5)
for node, pred_dist in prediction_distribution_dict.items():
if pred_dist.mus.shape[:2] != (1, 1):
return
means = pred_dist.mus.squeeze().cpu().numpy()
covs = pred_dist.get_covariance_matrix().squeeze().cpu().numpy()
pis = pred_dist.pis_cat_dist.probs.squeeze().cpu().numpy()
for timestep in range(means.shape[0]):
for z_val in range(means.shape[1]):
mean = means[timestep, z_val]
covar = covs[timestep, z_val]
pi = pis[timestep, z_val]
if pi < pi_threshold:
continue
v, w = linalg.eigh(covar)
v = 2. * np.sqrt(2.) * np.sqrt(v)
u = w[0] / linalg.norm(w[0])
# Plot an ellipse to show the Gaussian component
angle = np.arctan(u[1] / u[0])
angle = 180. * angle / np.pi # convert to degrees
ell = patches.Ellipse(mean, v[0], v[1], 180. + angle, color='blue' if node.type.name == 'VEHICLE' else 'orange')
ell.set_edgecolor(None)
ell.set_clip_box(ax.bbox)
ell.set_alpha(pi/10)
ax.add_artist(ell)