From 174198bd4f2119c4d9b440de482b64e3b8ad1719 Mon Sep 17 00:00:00 2001 From: BorisIvanovic Date: Wed, 9 Dec 2020 22:42:06 -0500 Subject: [PATCH] Adding ability to run on streaming data. --- README.md | 16 ++++ trajectron/environment/node.py | 8 +- trajectron/environment/scene.py | 37 +++++--- trajectron/environment/scene_graph.py | 10 +- trajectron/model/dataset/__init__.py | 2 +- trajectron/model/dataset/preprocessing.py | 3 +- trajectron/model/dynamics/unicycle.py | 12 +++ trajectron/model/mgcvae.py | 19 ++-- trajectron/model/online/online_mgcvae.py | 95 +++++++++++++------ trajectron/model/online/online_trajectron.py | 78 ++++++++-------- trajectron/test_online.py | 98 ++++++++++++++------ trajectron/train.py | 7 ++ trajectron/visualization/__init__.py | 2 +- trajectron/visualization/visualization.py | 43 ++++++++- 14 files changed, 305 insertions(+), 125 deletions(-) diff --git a/README.md b/README.md index b5c6960..42144a3 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,22 @@ If you instead wanted to evaluate a trained model's performance on forecasting p These scripts will produce csv files in the `results` directory which can then be analyzed in the `NuScenes Quantitative.ipynb` notebook. +## Online Execution ## +As of December 2020, this repository includes an "online" running capability. In addition to the regular batched mode for training and testing, Trajectron++ can now be executed online on streaming data! + +The `trajectron/test_online.py` script shows how to use it, and can be run as follows (depending on the desired model). + +| Model | Command | File Changes | +|-------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------| +| Base | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'vel_ee'` | +| +Dynamics Integration | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl | Line 110: `'int_ee'` | +| +Dynamics Integration, Maps | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding | Line 110: `'int_ee_me'` | +| +Dynamics Integration, Maps, Robot Future | python test_online.py --log_dir=../experiments/nuScenes/models --data_dir=../experiments/processed --conf=config.json --eval_data_dict=nuScenes_test_mini_full.pkl --map_encoding --incl_robot_node | Line 110: `'robot'` | + +Further, lines 145-151 can be changed to choose different scenes and starting timesteps. + +During running, each prediction will be iteratively visualized and saved in a `pred_figs/` folder within the specified model folder. For example, if the script loads the `int_ee` version of Trajectron++ then generated figures will be saved to `experiments/nuScenes/models/int_ee/pred_figs/`. 
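+
+If you want to drive the online model from your own code rather than through `test_online.py`, the core streaming loop is roughly the sketch below (condensed from that script; `trajectron` is an `OnlineTrajectron`, and `online_env`, `eval_scene`, `init_timestep`, and `hyperparams` are assumed to be set up as in `test_online.py`):
+
+```python
+# Prime the model with everything up to and including init_timestep,
+# then feed it one timestep of observations at a time.
+trajectron.set_environment(online_env, init_timestep)
+
+for timestep in range(init_timestep + 1, eval_scene.timesteps):
+    input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
+    dists, preds = trajectron.incremental_forward(input_dict,
+                                                  maps=None,  # or per-node crops from get_maps_for_input
+                                                  prediction_horizon=6,
+                                                  num_samples=1,
+                                                  full_dist=True)
+```
+
+Passing `maps` (see `get_maps_for_input` in `test_online.py`) and `robot_present_and_future` enables the map-encoding and robot-future model variants, respectively.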
+ ## Datasets ## ### ETH and UCY Pedestrian Datasets ### diff --git a/trajectron/environment/node.py b/trajectron/environment/node.py index bd128af..2db820e 100644 --- a/trajectron/environment/node.py +++ b/trajectron/environment/node.py @@ -46,7 +46,7 @@ class Node(object): def __repr__(self): return '/'.join([self.type.name, self.id]) - def overwrite_data(self, data, forward_in_time_on_next_overwrite=False): + def overwrite_data(self, data, header, forward_in_time_on_next_overwrite=False): """ This function hard overwrites the data matrix. When using it you have to make sure that the columns in the new data matrix correspond to the old structure. As well as setting first_timestep. @@ -55,7 +55,11 @@ class Node(object): :param forward_in_time_on_next_overwrite: On the !!NEXT!! call of overwrite_data first_timestep will be increased. :return: None """ - self.data.data = data + if header is None: + self.data.data = data + else: + self.data = DoubleHeaderNumpyArray(data, header) + self._last_timestep = None if self.forward_in_time_on_next_override: self.first_timestep += 1 diff --git a/trajectron/environment/scene.py b/trajectron/environment/scene.py index 7a5c60b..2991482 100644 --- a/trajectron/environment/scene.py +++ b/trajectron/environment/scene.py @@ -25,23 +25,28 @@ class Scene(object): self.non_aug_scene = non_aug_scene def add_robot_from_nodes(self, robot_type): - nodes_list = [node for node in self.nodes if node.type == robot_type] - non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3) - self.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True) + scenes = [self] + if hasattr(self, 'augmented'): + scenes += self.augmented - for node in non_overlapping_nodes: - self.nodes.remove(node) - self.nodes.append(self.robot) + for scn in scenes: + nodes_list = [node for node in scn.nodes if node.type == robot_type] + non_overlapping_nodes = MultiNode.find_non_overlapping_nodes(nodes_list, min_timesteps=3) + scn.robot = MultiNode(robot_type, 'ROBOT', non_overlapping_nodes, is_robot=True) - def get_clipped_pos_dict(self, timestep, state): - pos_dict = dict() + for node in non_overlapping_nodes: + scn.nodes.remove(node) + scn.nodes.append(scn.robot) + + def get_clipped_input_dict(self, timestep, state): + input_dict = dict() existing_nodes = self.get_nodes_clipped_at_time(timesteps=np.array([timestep]), state=state) tr_scene = np.array([timestep, timestep]) for node in existing_nodes: - pos_dict[node] = node.get(tr_scene, {'position': ['x', 'y']}) + input_dict[node] = node.get(tr_scene, state[node.type]) - return pos_dict + return input_dict def get_scene_graph(self, timestep, @@ -160,6 +165,7 @@ class Scene(object): return clipped_nodes tr_scene = np.array([timesteps.min(), timesteps.max()]) + data_header_memo = dict() for node in all_nodes: if isinstance(node, MultiNode): copied_node = copy.deepcopy(node.get_node_at_timesteps(tr_scene)) @@ -168,7 +174,16 @@ class Scene(object): copied_node = copy.deepcopy(node) clipped_value = node.get(tr_scene, state[node.type]) - copied_node.overwrite_data(clipped_value) + + if node.type not in data_header_memo: + data_header = list() + for quantity, values in state[node.type].items(): + for value in values: + data_header.append((quantity, value)) + + data_header_memo[node.type] = data_header + + copied_node.overwrite_data(clipped_value, data_header_memo[node.type]) copied_node.first_timestep = tr_scene[0] clipped_nodes.append(copied_node) diff --git a/trajectron/environment/scene_graph.py 
b/trajectron/environment/scene_graph.py index 3806d4d..1113bd4 100644 --- a/trajectron/environment/scene_graph.py +++ b/trajectron/environment/scene_graph.py @@ -292,7 +292,7 @@ class SceneGraph(object): other_types = set(node.type for node in other.nodes) all_node_types = our_types | other_types - new_neighbors = defaultdict(dict) + new_neighbors = defaultdict(lambda: defaultdict(set)) for node in self.nodes: if node in removed_nodes: continue @@ -306,9 +306,9 @@ class SceneGraph(object): for node_type in our_types: neighbors = self.get_neighbors(node, node_type) if len(neighbors) > 0: - new_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)} + new_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors) - removed_neighbors = defaultdict(dict) + removed_neighbors = defaultdict(lambda: defaultdict(set)) for node in other.nodes: if node in removed_nodes: continue @@ -322,13 +322,13 @@ class SceneGraph(object): for node_type in other_types: neighbors = other.get_neighbors(node, node_type) if len(neighbors) > 0: - removed_neighbors[node] = {DirectedEdge.get_edge_type(node, Node(node_type, None, None)): set(neighbors)} + removed_neighbors[node][DirectedEdge.get_edge_type(node, Node(node_type, None, None))] = set(neighbors) return new_nodes, removed_nodes, new_neighbors, removed_neighbors if __name__ == '__main__': - from data import NodeTypeEnum + from environment import NodeTypeEnum import time # # # # # # # # # # # # # # # # # diff --git a/trajectron/model/dataset/__init__.py b/trajectron/model/dataset/__init__.py index a18c16c..a01f88e 100644 --- a/trajectron/model/dataset/__init__.py +++ b/trajectron/model/dataset/__init__.py @@ -1,2 +1,2 @@ from .dataset import EnvironmentDataset, NodeTypeDataset -from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore +from .preprocessing import collate, get_node_timestep_data, get_timesteps_data, restore, get_relative_robot_traj diff --git a/trajectron/model/dataset/preprocessing.py b/trajectron/model/dataset/preprocessing.py index e36f37a..c8a0e8f 100644 --- a/trajectron/model/dataset/preprocessing.py +++ b/trajectron/model/dataset/preprocessing.py @@ -159,8 +159,9 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, else: robot = scene.robot robot_type = robot.type - robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0) + robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan) robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type) + robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0 # Map map_tuple = None diff --git a/trajectron/model/dynamics/unicycle.py b/trajectron/model/dynamics/unicycle.py index 2026a5e..1a83c42 100644 --- a/trajectron/model/dynamics/unicycle.py +++ b/trajectron/model/dynamics/unicycle.py @@ -62,6 +62,12 @@ class Unicycle(Dynamic): ph = control_samples.shape[-2] p_0 = self.initial_conditions['pos'].unsqueeze(1) v_0 = self.initial_conditions['vel'].unsqueeze(1) + + # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. 
+ if p_0.size()[0] != x.size()[0]: + p_0 = p_0.repeat(x.size()[0], 1, 1) + v_0 = v_0.repeat(x.size()[0], 1, 1) + phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0]) phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) @@ -193,6 +199,12 @@ class Unicycle(Dynamic): ph = control_dist_dphi_a.mus.shape[-3] p_0 = self.initial_conditions['pos'].unsqueeze(1) v_0 = self.initial_conditions['vel'].unsqueeze(1) + + # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. + if p_0.size()[0] != x.size()[0]: + p_0 = p_0.repeat(x.size()[0], 1, 1) + v_0 = v_0.repeat(x.size()[0], 1, 1) + phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0]) phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) diff --git a/trajectron/model/mgcvae.py b/trajectron/model/mgcvae.py index 47d09c2..8760d6d 100644 --- a/trajectron/model/mgcvae.py +++ b/trajectron/model/mgcvae.py @@ -367,10 +367,11 @@ class MultimodalGenerativeCVAE(object): neighbors_edge_value, robot, map) -> (torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor): + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor): """ Encodes input and output tensors for node and robot. @@ -471,7 +472,7 @@ class MultimodalGenerativeCVAE(object): if self.log_writer and (self.curr_iter + 1) % 500 == 0: map_clone = map.clone() map_patch = self.hyperparams['map_encoder'][self.node_type]['patch_size'] - map_clone[:, :, map_patch[1]-5:map_patch[1]+5, map_patch[0]-5:map_patch[0]+5] = 1. + map_clone[:, :, map_patch[1] - 5:map_patch[1] + 5, map_patch[0] - 5:map_patch[0] + 5] = 1. self.log_writer.add_images(f"{self.node_type}/cropped_maps", map_clone, self.curr_iter, dataformats='NCWH') @@ -812,10 +813,10 @@ class MultimodalGenerativeCVAE(object): state = initial_state if self.hyperparams['incl_robot_node']: input_ = torch.cat([zx, - a_0.repeat(num_samples*num_components, 1), - x_nr_t.repeat(num_samples*num_components, 1)], dim=1) + a_0.repeat(num_samples * num_components, 1), + x_nr_t.repeat(num_samples * num_components, 1)], dim=1) else: - input_ = torch.cat([zx, a_0.repeat(num_samples*num_components, 1)], dim=1) + input_ = torch.cat([zx, a_0.repeat(num_samples * num_components, 1)], dim=1) for j in range(ph): h_state = cell(input_, state) @@ -989,7 +990,7 @@ class MultimodalGenerativeCVAE(object): z, kl = self.encoder(mode, x, y_e) log_p_y_xz = self.decoder(mode, x, x_nr_t, y, y_r, n_s_t0, z, - labels, # Loss is calculated on unstandardized label + labels, # Loss is calculated on unstandardized label prediction_horizon, self.hyperparams['k']) diff --git a/trajectron/model/online/online_mgcvae.py b/trajectron/model/online/online_mgcvae.py index fd61b5c..c614c37 100644 --- a/trajectron/model/online/online_mgcvae.py +++ b/trajectron/model/online/online_mgcvae.py @@ -6,6 +6,7 @@ import numpy as np from collections import defaultdict, Counter from model.components import * from model.model_utils import * +from model.dataset import get_relative_robot_traj import model.dynamics as dynamic_module from model.mgcvae import MultimodalGenerativeCVAE from environment.scene_graph import DirectedEdge @@ -48,7 +49,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): dynamic_class = getattr(dynamic_module, self.hyperparams['dynamic'][self.node_type]['name']) dyn_limits = hyperparams['dynamic'][self.node_type]['limits'] - self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device, self.model_registrar, self.x_size) + self.dynamic = 
dynamic_class(self.env.scenes[0].dt, dyn_limits, device, + self.model_registrar, self.x_size, self.node_type) def create_graphical_model(self): """ @@ -127,11 +129,13 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0: del self.node_modules[edge_type + '/edge_encoder'] - def obtain_encoded_tensors(self, mode, inputs, inputs_st, inputs_np, robot_present_and_future) -> (torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor): + def obtain_encoded_tensors(self, + mode, + inputs, + inputs_st, + inputs_np, + robot_present_and_future, + maps): x, x_r_t, y_r = None, None, None batch_size = 1 @@ -139,15 +143,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): our_inputs_st = inputs_st[self.node] initial_dynamics = dict() - initial_dynamics['pos'] = our_inputs_st[:, 0:2] # TODO: Generalize - initial_dynamics['vel'] = our_inputs_st[:, 2:4] # TODO: Generalize + initial_dynamics['pos'] = our_inputs[:, 0:2] # TODO: Generalize + initial_dynamics['vel'] = our_inputs[:, 2:4] # TODO: Generalize self.dynamic.set_initial_condition(initial_dynamics) ######################################### # Provide basic information to encoders # ######################################### if self.hyperparams['incl_robot_node'] and self.robot is not None: - x_r_t, y_r = self.get_relative_robot_traj(our_inputs, robot_present_and_future, self.robot.type) + robot_present_and_future_st = get_relative_robot_traj(self.env, self.state, + our_inputs, robot_present_and_future, + self.node.type, self.robot.type) + x_r_t = robot_present_and_future_st[..., 0, :] + y_r = robot_present_and_future_st[..., 1:, :] ################## # Encode History # @@ -194,6 +202,21 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): self.TD = {'node_history_encoded': node_history_encoded, 'total_edge_influence': total_edge_influence} + ################ + # Map Encoding # + ################ + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + if self.node not in maps: + # This means the node was removed (it is only being kept around because of the edge removal filter). + me_params = self.hyperparams['map_encoder'][self.node_type] + self.TD['encoded_map'] = torch.zeros((1, me_params['output_size'])) + else: + encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. 
- 1., + (mode == ModeKeys.TRAIN)) + do = self.hyperparams['map_encoder'][self.node_type]['dropout'] + encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN)) + self.TD['encoded_map'] = encoded_map + ###################################### # Concatenate Encoder Outputs into x # ###################################### @@ -207,6 +230,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): node_history_encoded = TD['node_history_encoded'] if self.hyperparams['edge_encoding']: total_edge_influence = TD['total_edge_influence'] + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + encoded_map = TD['encoded_map'] if (self.hyperparams['incl_robot_node'] and self.robot is not None @@ -222,6 +247,8 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1) if self.hyperparams['edge_encoding']: total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1) + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + encoded_map = TD['encoded_map'].repeat(robot_future_st.size()[0], 1) elif self.hyperparams['incl_robot_node'] and self.robot is not None: # Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends). @@ -239,6 +266,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): if self.hyperparams['incl_robot_node'] and self.robot is not None: x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history] + if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']: + x_concat_list.append(encoded_map) # [bs/nbs, CNN output size] + return torch.cat(x_concat_list, dim=1) def encode_node_history(self, inputs_st): @@ -258,13 +288,19 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): edge_states_list = list() # list of [#of neighbors, max_ht, state_dim] neighbor_states = list() - rel_state = inputs[self.node].cpu().numpy() + orig_rel_state = inputs[self.node].cpu().numpy() for node in connected_nodes[0]: neighbor_state_np = inputs_np[node] # Make State relative to node _, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type) std[0:2] = self.env.attention_radius[edge_type_tuple] + + # TODO: This all makes the unsafe assumption that the first n dims + # refer to the same quantities even for different agent types! + equal_dims = np.min((neighbor_state_np.shape[-1], orig_rel_state.shape[-1])) + rel_state = np.zeros_like(neighbor_state_np) + rel_state[..., :equal_dims] = orig_rel_state[..., :equal_dims] neighbor_state_np_st = self.env.standardize(neighbor_state_np, self.state[node.type], node_type=node.type, @@ -333,7 +369,7 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): else: return outputs[:, 0, :] # [bs, enc_rnn_dim] - def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None): + def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None, maps=None): # Always predicting with the online model. 
mode = ModeKeys.PREDICT @@ -341,7 +377,9 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): inputs, inputs_st, inputs_np, - robot_present_and_future) + robot_present_and_future, + maps) + self.n_s_t0 = inputs_st[self.node] self.latent.p_dist = self.p_z_x(mode, self.x) @@ -361,27 +399,32 @@ class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE): if (self.hyperparams['incl_robot_node'] and self.robot is not None and robot_present_and_future is not None): - x_nr_t, y_r = self.get_relative_robot_traj( - torch.tensor(self.node.get(np.array([self.node.last_timestep]), - self.state[self.node.type], - padding=0.0), - dtype=torch.float, - device=self.device), - robot_present_and_future, - self.robot.type) + our_inputs = torch.tensor(self.node.get(np.array([self.node.last_timestep]), + self.state[self.node.type], + padding=0.0), + dtype=torch.float, + device=self.device) + robot_present_and_future_st = get_relative_robot_traj(self.env, self.state, + our_inputs, robot_present_and_future, + self.node.type, self.robot.type) + x_nr_t = robot_present_and_future_st[..., 0, :] + y_r = robot_present_and_future_st[..., 1:, :] self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r) self.latent.p_dist = self.p_z_x(mode, self.x) + # Making sure n_s_t0 has the same batch size as x_nr_t + self.n_s_t0 = self.n_s_t0[[0]].repeat(x_nr_t.size()[0], 1) + z, num_samples, num_components = self.latent.sample_p(num_samples, mode, most_likely_z=z_mode, full_dist=full_dist, all_z_sep=all_z_sep) - _, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, z, - prediction_horizon, - num_samples, - num_components, - gmm_mode) + y_dist, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, self.n_s_t0, z, + prediction_horizon, + num_samples, + num_components, + gmm_mode) - return our_sampled_future + return y_dist, our_sampled_future diff --git a/trajectron/model/online/online_trajectron.py b/trajectron/model/online/online_trajectron.py index 33c35dd..f1c5063 100644 --- a/trajectron/model/online/online_trajectron.py +++ b/trajectron/model/online/online_trajectron.py @@ -4,7 +4,7 @@ from collections import Counter from model.trajectron import Trajectron from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE from model.model_utils import ModeKeys -from data import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of +from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of class OnlineTrajectron(Trajectron): @@ -58,10 +58,11 @@ class OnlineTrajectron(Trajectron): # Fast-forwarding ourselves to the initial timestep, without running any of the underlying models. 
for timestep in range(init_timestep + 1): - self.incremental_forward(self.env.scenes[0].get_clipped_pos_dict(timestep, self.hyperparams['state']), - run_models=False) + self.incremental_forward(self.env.scenes[0].get_clipped_input_dict(timestep, self.hyperparams['state']), + maps=None, run_models=False) def incremental_forward(self, new_inputs_dict, + maps, prediction_horizon=0, num_samples=0, robot_present_and_future=None, @@ -80,7 +81,8 @@ class OnlineTrajectron(Trajectron): for node, new_input in new_inputs_dict.items(): if node not in self.node_data: self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY, - dtype=(float, len(self.state[node.type]['position']))) + dtype=(float, sum(len(self.state[node.type][k]) + for k in self.state[node.type]))) self.node_data[node].append(new_input) if node in self.removed_nodes: @@ -101,17 +103,10 @@ class OnlineTrajectron(Trajectron): # that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role # (the NaNs get converted to zeros later on). for node in self.removed_nodes: - self.node_data[node].append(np.full((1, 2), np.nan)) + self.node_data[node].append(np.full((1, self.node_data[node].shape[1]), np.nan)) for node in self.node_data: - x = self.node_data[node][:, 0] - y = self.node_data[node][:, 1] - vx = derivative_of(x, self.env.scenes[0].dt) - vy = derivative_of(y, self.env.scenes[0].dt) - ax = derivative_of(vx, self.env.scenes[0].dt) - ay = derivative_of(vy, self.env.scenes[0].dt) - new_node_data = np.stack([x, y, vx, vy, ax, ay], axis=-1) - node.overwrite_data(new_node_data, + node.overwrite_data(self.node_data[node], None, forward_in_time_on_next_overwrite=(self.node_data[node].shape[0] == self.RING_CAPACITY)) @@ -139,14 +134,15 @@ class OnlineTrajectron(Trajectron): # These next 2 for loops add or remove entire node models. for node in new_nodes: - if node.is_robot: + if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys(): + # Only deal with Models for NodeTypes we want to predict continue self._add_node_model(node) self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors) for node in removed_nodes: - if node.is_robot: + if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys(): continue self._remove_node_model(node) @@ -157,7 +153,8 @@ class OnlineTrajectron(Trajectron): inputs_st = dict() inputs_np = dict() - iter_list = list(self.node_models_dict.keys()) + iter_list = list(self.node_models_dict.keys()) + [node for node in new_inputs_dict + if node.type not in self.pred_state.keys()] if self.env.scenes[0].robot is not None: iter_list.append(self.env.scenes[0].robot) @@ -191,12 +188,15 @@ class OnlineTrajectron(Trajectron): robot_present_and_future = robot_present_and_future[np.newaxis, :] assert robot_present_and_future.shape[1] == prediction_horizon + 1 + robot_present_and_future = torch.tensor(robot_present_and_future, + dtype=torch.float, device=self.device) for node in self.node_models_dict: self.node_models_dict[node].encoder_forward(inputs, inputs_st, inputs_np, - robot_present_and_future) + robot_present_and_future, + maps) # If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all, # just update the encoder LSTMs. 
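A side note on the NaN-placeholder mechanism in `incremental_forward` above: removed nodes are kept around only to satisfy the edge removal filter, and appending NaN rows keeps each node's encoder LSTM stepping forward without injecting new observations, because the NaNs are zeroed out before the recurrent update. A standalone illustration of the idea (an assumption-laden sketch, not code from this patch):

```python
import torch
import torch.nn as nn

state_dim, hidden_dim = 6, 32
lstm = nn.LSTMCell(state_dim, hidden_dim)
h, c = torch.zeros(1, hidden_dim), torch.zeros(1, hidden_dim)

placeholder = torch.full((1, state_dim), float('nan'))  # row appended for a removed node
masked = torch.where(torch.isnan(placeholder),
                     torch.zeros_like(placeholder), placeholder)

# The hidden state keeps propagating; the (all-zero) input carries no new observation.
h, c = lstm(masked, (h, c))
```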
@@ -220,21 +220,18 @@ class OnlineTrajectron(Trajectron): full_dist=False, all_z_sep=False): model = self.node_models_dict[node] - predictions = model.decoder_forward(num_predicted_timesteps, - num_samples, - robot_present_and_future=robot_present_and_future, - z_mode=z_mode, - gmm_mode=gmm_mode, - full_dist=full_dist, - all_z_sep=all_z_sep) + prediction_dist, predictions_uns = model.decoder_forward(num_predicted_timesteps, + num_samples, + robot_present_and_future=robot_present_and_future, + z_mode=z_mode, + gmm_mode=gmm_mode, + full_dist=full_dist, + all_z_sep=all_z_sep) - predictions_uns = self.env.unstandardize(predictions.cpu().detach().numpy(), - self.pred_state[node.type.name], - node.type, - mean=self.rel_states[node][..., 0:2]) + predictions_np = predictions_uns.cpu().detach().numpy() # Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2) - return np.transpose(predictions_uns, (1, 0, 2, 3)) + return prediction_dist, np.transpose(predictions_np, (1, 0, 2, 3)) def sample_model(self, num_predicted_timesteps, num_samples, @@ -262,23 +259,24 @@ class OnlineTrajectron(Trajectron): # No grad since we're predicting always, as evidenced by the line above. with torch.no_grad(): predictions_dict = dict() + prediction_dists = dict() for node in set(self.nodes) - set(self.removed_nodes.keys()): if node.is_robot: continue - predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps, - num_samples, - robot_present_and_future, - z_mode, - gmm_mode, - full_dist, - all_z_sep) + prediction_dists[node], predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps, + num_samples, + robot_present_and_future, + z_mode, + gmm_mode, + full_dist, + all_z_sep) - return predictions_dict + return prediction_dists, predictions_dict def forward(self, init_env, init_timestep, - pos_dicts, # After the initial environment + input_dicts, # After the initial environment num_predicted_timesteps, num_samples, robot_present_and_future=None, @@ -294,8 +292,8 @@ class OnlineTrajectron(Trajectron): self.set_environment(init_env, init_timestep) # Looping through and applying updates to the model. 
- for i in range(len(pos_dicts)): - self.incremental_forward(pos_dicts[i]) + for i in range(len(input_dicts)): + self.incremental_forward(input_dicts[i]) return self.sample_model(num_predicted_timesteps, num_samples, diff --git a/trajectron/test_online.py b/trajectron/test_online.py index 1539368..3e6cae7 100644 --- a/trajectron/test_online.py +++ b/trajectron/test_online.py @@ -2,7 +2,7 @@ import os import time import json import torch -import pickle +import dill import random import pathlib import evaluation @@ -11,7 +11,7 @@ import visualization as vis from argument_parser import args from model.online.online_trajectron import OnlineTrajectron from model.model_registrar import ModelRegistrar -from data import Environment, Scene +from environment import Environment, Scene import matplotlib.pyplot as plt if not torch.cuda.is_available() or args.device == 'cpu': @@ -58,8 +58,56 @@ def create_online_env(env, hyperparams, scene_idx, init_timestep): robot_type=env.robot_type) +def get_maps_for_input(input_dict, scene, hyperparams): + scene_maps = list() + scene_pts = list() + heading_angles = list() + patch_sizes = list() + nodes_with_maps = list() + for node in input_dict: + if node.type in hyperparams['map_encoder']: + x = input_dict[node] + me_hyp = hyperparams['map_encoder'][node.type] + if 'heading_state_index' in me_hyp: + heading_state_index = me_hyp['heading_state_index'] + # We have to rotate the map in the opposit direction of the agent to match them + if type(heading_state_index) is list: # infer from velocity or heading vector + heading_angle = -np.arctan2(x[-1, heading_state_index[1]], + x[-1, heading_state_index[0]]) * 180 / np.pi + else: + heading_angle = -x[-1, heading_state_index] * 180 / np.pi + else: + heading_angle = None + + scene_map = scene.map[node.type] + map_point = x[-1, :2] + + patch_size = hyperparams['map_encoder'][node.type]['patch_size'] + + scene_maps.append(scene_map) + scene_pts.append(map_point) + heading_angles.append(heading_angle) + patch_sizes.append(patch_size) + nodes_with_maps.append(node) + + if heading_angles[0] is None: + heading_angles = None + else: + heading_angles = torch.Tensor(heading_angles) + + maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps, + scene_pts=torch.Tensor(scene_pts), + patch_size=patch_sizes[0], + rotation=heading_angles) + + maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)} + return maps_dict + + def main(): - model_dir = os.path.join(args.log_dir, 'models_14_Jan_2020_00_24_21eth_no_rob') + # Choose one of the model directory names under the experiment/*/models folders. 
+ # Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot' + model_dir = os.path.join(args.log_dir, 'int_ee') # Load hyperparameters from json config_file = os.path.join(model_dir, args.conf) @@ -78,27 +126,22 @@ def main(): hyperparams['k_eval'] = args.k_eval hyperparams['offline_scene_graph'] = args.offline_scene_graph hyperparams['incl_robot_node'] = args.incl_robot_node - hyperparams['scene_batch_size'] = args.scene_batch_size - hyperparams['node_resample_train'] = args.node_resample_train - hyperparams['node_resample_eval'] = args.node_resample_eval - hyperparams['scene_resample_train'] = args.scene_resample_train - hyperparams['scene_resample_eval'] = args.scene_resample_eval - hyperparams['scene_resample_viz'] = args.scene_resample_viz hyperparams['edge_encoding'] = not args.no_edge_encoding + hyperparams['use_map_encoding'] = args.map_encoding output_save_dir = os.path.join(model_dir, 'pred_figs') pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True) eval_data_path = os.path.join(args.data_dir, args.eval_data_dict) with open(eval_data_path, 'rb') as f: - eval_env = pickle.load(f, encoding='latin1') + eval_env = dill.load(f, encoding='latin1') if eval_env.robot_type is None and hyperparams['incl_robot_node']: eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify? for scene in eval_env.scenes: scene.add_robot_from_nodes(eval_env.robot_type) - print('Loaded evaluation data from %s' % (eval_data_path,)) + print('Loaded data from %s' % (eval_data_path,)) # Creating a dummy environment with a single scene that contains information about the world. # When using this code, feel free to use whichever scene index or initial timestep you wish. @@ -112,7 +155,7 @@ def main(): online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep) model_registrar = ModelRegistrar(model_dir, args.eval_device) - model_registrar.load_models(iter_num=1999) + model_registrar.load_models(iter_num=12) trajectron = OnlineTrajectron(model_registrar, hyperparams, @@ -126,10 +169,14 @@ def main(): trajectron.set_environment(online_env, init_timestep) for timestep in range(init_timestep + 1, eval_scene.timesteps): - pos_dict = eval_scene.get_clipped_pos_dict(timestep, hyperparams['state']) + input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state']) + + maps = None + if hyperparams['use_map_encoding']: + maps = get_maps_for_input(input_dict, eval_scene, hyperparams) robot_present_and_future = None - if eval_scene.robot is not None: + if eval_scene.robot is not None and hyperparams['incl_robot_node']: robot_present_and_future = eval_scene.robot.get(np.array([timestep, timestep + hyperparams['prediction_horizon']]), hyperparams['state'][eval_scene.robot.type], @@ -138,10 +185,12 @@ def main(): # robot_present_and_future += adjustment start = time.time() - preds = trajectron.incremental_forward(pos_dict, - prediction_horizon=12, - num_samples=25, - robot_present_and_future=robot_present_and_future) + dists, preds = trajectron.incremental_forward(input_dict, + maps, + prediction_horizon=6, + num_samples=1, + robot_present_and_future=robot_present_and_future, + full_dist=True) end = time.time() print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, 1. 
/ (end - start), len(trajectron.nodes), @@ -152,23 +201,16 @@ def main(): if node in preds: detailed_preds_dict[node] = preds[node] - batch_stats = evaluation.compute_batch_statistics({timestep: detailed_preds_dict}, - eval_scene.dt, - max_hl=hyperparams['maximum_history_length'], - ph=hyperparams['prediction_horizon'], - node_type_enum=online_env.NodeType, - prune_ph_to_future=True) - - evaluation.print_batch_errors([batch_stats], 'eval', timestep) - fig, ax = plt.subplots() + vis.visualize_distribution(ax, + dists) vis.visualize_prediction(ax, {timestep: preds}, eval_scene.dt, hyperparams['maximum_history_length'], hyperparams['prediction_horizon']) - if eval_scene.robot is not None: + if eval_scene.robot is not None and hyperparams['incl_robot_node']: robot_for_plotting = eval_scene.robot.get(np.array([timestep, timestep + hyperparams['prediction_horizon']]), hyperparams['state'][eval_scene.robot.type]) diff --git a/trajectron/train.py b/trajectron/train.py index 7857f1a..0e40b5b 100644 --- a/trajectron/train.py +++ b/trajectron/train.py @@ -132,6 +132,9 @@ def main(): return_robot=not args.incl_robot_node) train_data_loader = dict() for node_type_data_set in train_dataset: + if len(node_type_data_set) == 0: + continue + node_type_dataloader = utils.data.DataLoader(node_type_data_set, collate_fn=collate, pin_memory=False if args.device is 'cpu' else True, @@ -172,6 +175,9 @@ def main(): return_robot=not args.incl_robot_node) eval_data_loader = dict() for node_type_data_set in eval_dataset: + if len(node_type_data_set) == 0: + continue + node_type_dataloader = utils.data.DataLoader(node_type_data_set, collate_fn=collate, pin_memory=False if args.eval_device is 'cpu' else True, @@ -285,6 +291,7 @@ def main(): predictions = trajectron.predict(scene, timestep, ph, + min_future_timesteps=ph, z_mode=True, gmm_mode=True, all_z_sep=False, diff --git a/trajectron/visualization/__init__.py b/trajectron/visualization/__init__.py index e9af42c..1f92021 100644 --- a/trajectron/visualization/__init__.py +++ b/trajectron/visualization/__init__.py @@ -1,2 +1,2 @@ -from .visualization import visualize_prediction +from .visualization import visualize_prediction, visualize_distribution from .visualization_utils import plot_boxplots \ No newline at end of file diff --git a/trajectron/visualization/visualization.py b/trajectron/visualization/visualization.py index 50ea2c5..08e1fef 100644 --- a/trajectron/visualization/visualization.py +++ b/trajectron/visualization/visualization.py @@ -1,5 +1,7 @@ from utils import prediction_output_to_trajectories +from scipy import linalg import matplotlib.pyplot as plt +import matplotlib.patches as patches import matplotlib.patheffects as pe import numpy as np import seaborn as sns @@ -86,4 +88,43 @@ def visualize_prediction(ax, if map is not None: ax.imshow(map.as_image(), origin='lower', alpha=0.5) - plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs) \ No newline at end of file + plot_trajectories(ax, prediction_dict, histories_dict, futures_dict, *kwargs) + + +def visualize_distribution(ax, + prediction_distribution_dict, + map=None, + pi_threshold=0.05, + **kwargs): + if map is not None: + ax.imshow(map.as_image(), origin='lower', alpha=0.5) + + for node, pred_dist in prediction_distribution_dict.items(): + if pred_dist.mus.shape[:2] != (1, 1): + return + + means = pred_dist.mus.squeeze().cpu().numpy() + covs = pred_dist.get_covariance_matrix().squeeze().cpu().numpy() + pis = pred_dist.pis_cat_dist.probs.squeeze().cpu().numpy() + + for 
timestep in range(means.shape[0]): + for z_val in range(means.shape[1]): + mean = means[timestep, z_val] + covar = covs[timestep, z_val] + pi = pis[timestep, z_val] + + if pi < pi_threshold: + continue + + v, w = linalg.eigh(covar) + v = 2. * np.sqrt(2.) * np.sqrt(v) + u = w[0] / linalg.norm(w[0]) + + # Plot an ellipse to show the Gaussian component + angle = np.arctan(u[1] / u[0]) + angle = 180. * angle / np.pi # convert to degrees + ell = patches.Ellipse(mean, v[0], v[1], 180. + angle, color='blue' if node.type.name == 'VEHICLE' else 'orange') + ell.set_edgecolor(None) + ell.set_clip_box(ax.bbox) + ell.set_alpha(pi/10) + ax.add_artist(ell)
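
A closing note on the ellipse geometry in `visualize_distribution`: `linalg.eigh` returns the per-axis variances and the eigenvector matrix of each 2x2 component covariance, `2 * sqrt(2) * sqrt(v)` is used as the full axis length (so each plotted ellipse is roughly the sqrt(2)-sigma contour, enclosing about 63% of that component's probability mass), and the rotation angle is read off the eigenvector matrix. A tiny standalone sketch (illustrative only, not part of the patch) that reproduces those parameters for a single covariance:

```python
import numpy as np
from scipy import linalg

covar = np.array([[2.0, 0.6],
                  [0.6, 1.0]])           # one 2x2 GMM component covariance

v, w = linalg.eigh(covar)                # eigenvalues (variances) and eigenvector matrix
widths = 2. * np.sqrt(2.) * np.sqrt(v)   # full ellipse axis lengths, as in visualize_distribution
u = w[0] / linalg.norm(w[0])             # direction used to orient the ellipse
angle_deg = 180. * np.arctan(u[1] / u[0]) / np.pi

print(widths, angle_deg)
```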