Trajectron-plus-plus/trajectron/model/online/online_trajectron.py
2020-12-09 22:42:06 -05:00

304 lines
16 KiB
Python

import torch
import numpy as np
from collections import Counter
from model.trajectron import Trajectron
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
from model.model_utils import ModeKeys
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
class OnlineTrajectron(Trajectron):
    """Trajectron variant that consumes observations one timestep at a time.

    Each node's recent states are kept in a fixed-capacity ring buffer. On
    every call to ``incremental_forward`` the buffers are extended, the scene
    graph is rebuilt, node models are added/removed to match it, and
    (optionally) the per-node encoders and decoders are run.
    """

    def __init__(self, model_registrar, hyperparams, device):
        """Initialise the base model (without a log writer) and online state.

        Args:
            model_registrar: Forwarded to ``Trajectron.__init__`` and to every
                node model created by ``_add_node_model``.
            hyperparams: Hyperparameter mapping. This class reads the keys
                'edge_removal_filter', 'edge_addition_filter',
                'maximum_history_length', 'state', 'dynamic_edges' and
                'incl_robot_node'.
            device: Device on which input tensors are created.
        """
        super().__init__(model_registrar=model_registrar,
                         hyperparams=hyperparams,
                         log_writer=False,
                         device=device)

        # node -> RingBuffer holding that node's most recent state rows.
        self.node_data = dict()
        # Scene graph of the latest processed timestep (set in set_environment).
        self.scene_graph = None
        # Buffers must cover the longest of the two edge filters and the
        # maximum history length, plus one extra slot.
        self.RING_CAPACITY = max(len(self.hyperparams['edge_removal_filter']),
                                 len(self.hyperparams['edge_addition_filter']),
                                 self.hyperparams['maximum_history_length']) + 1
        # node -> array used as the standardization mean on the last encoder update.
        self.rel_states = dict()
        # node -> number of consecutive updates for which the node was absent.
        self.removed_nodes = Counter()

    def __repr__(self):
        return f"OnlineTrajectron(# nodes: {len(self.nodes)}, device: {self.device}, hyperparameters: {str(self.hyperparams)}) "

    def _add_node_model(self, node):
        """Register ``node`` and create its online node model.

        Raises:
            ValueError: If ``node`` is already registered.
        """
        if node in self.nodes:
            raise ValueError('%s was already added to this graph!' % str(node))

        self.nodes.add(node)
        self.node_models_dict[node] = OnlineMultimodalGenerativeCVAE(self.env,
                                                                     node,
                                                                     self.model_registrar,
                                                                     self.hyperparams,
                                                                     self.device)

    def update_removed_nodes(self):
        """Drop buffered data of nodes absent for at least the removal-filter length."""
        removal_window = len(self.hyperparams['edge_removal_filter'])
        # Iterate over a snapshot because entries are deleted inside the loop.
        for node in list(self.removed_nodes):
            if self.removed_nodes[node] >= removal_window:
                del self.node_data[node]
                del self.removed_nodes[node]

    def _remove_node_model(self, node):
        """Unregister ``node`` and discard its node model.

        Raises:
            ValueError: If ``node`` is not registered.
        """
        if node not in self.nodes:
            raise ValueError('%s is not in this graph!' % str(node))

        self.nodes.remove(node)
        del self.node_models_dict[node]

    def set_environment(self, env, init_timestep=0):
        """Reset all online state to ``env`` and replay up to ``init_timestep``.

        Args:
            env: Environment to attach; its first scene (``env.scenes[0]``)
                supplies the replayed inputs.
            init_timestep: Last timestep (inclusive) to replay into the buffers.
        """
        self.env = env
        self.scene_graph = SceneGraph(edge_radius=self.env.attention_radius)
        self.nodes.clear()
        self.node_data.clear()
        self.node_models_dict.clear()
        # Removal counters and relative states refer to the buffers cleared
        # above. Leaving stale counters behind would make the next update look
        # up (or delete) node_data entries that no longer exist.
        self.rel_states.clear()
        self.removed_nodes.clear()

        # Fast-forwarding ourselves to the initial timestep, without running any of the underlying models.
        for timestep in range(init_timestep + 1):
            self.incremental_forward(self.env.scenes[0].get_clipped_input_dict(timestep, self.hyperparams['state']),
                                     maps=None, run_models=False)

    def incremental_forward(self, new_inputs_dict,
                            maps,
                            prediction_horizon=0,
                            num_samples=0,
                            robot_present_and_future=None,
                            z_mode=False,
                            gmm_mode=False,
                            full_dist=False,
                            all_z_sep=False,
                            run_models=True):
        """Ingest one timestep of observations and optionally predict.

        The way this function works is by appending the new datapoints to the
        ends of each of the LSTMs in the graph. Then, we recalculate the
        encoder's output vector h_x and feed that into the decoder to sample
        new outputs.

        Args:
            new_inputs_dict: Mapping node -> new observation; each value is
                appended to that node's ring buffer.
            maps: Passed unchanged to every node model's ``encoder_forward``.
            prediction_horizon: Number of timesteps to predict; 0 skips decoding.
            num_samples: Number of samples to draw; 0 skips decoding.
            robot_present_and_future: Optional robot trajectory. A 2-D input
                gets a leading batch axis; axis 1 must have length
                ``prediction_horizon + 1``.
            z_mode, gmm_mode, full_dist, all_z_sep: Forwarded to the decoder.
            run_models: When False only the buffers, scene graph and set of
                node models are updated; the encoders are not run.

        Returns:
            ``None`` when ``prediction_horizon`` or ``num_samples`` is 0,
            otherwise the result of ``sample_model``.
        """
        # No grad since this method is only ever used for prediction.
        with torch.no_grad():
            for node, new_input in new_inputs_dict.items():
                if node not in self.node_data:
                    self.node_data[node] = RingBuffer(capacity=self.RING_CAPACITY,
                                                      dtype=(float, sum(len(self.state[node.type][k])
                                                                        for k in self.state[node.type])))
                self.node_data[node].append(new_input)

                # A node that shows up again is no longer considered removed.
                if node in self.removed_nodes:
                    del self.removed_nodes[node]

            # Nodes in self.node_data that aren't in new_inputs_dict were just removed.
            newly_removed_nodes = (set(self.node_data.keys()) - set(self.removed_nodes.keys())) - set(
                new_inputs_dict.keys())

            # We update self.removed_nodes with the newly removed nodes as well as all existing removed nodes to get
            # the time since their last removal increased by one.
            self.removed_nodes.update(newly_removed_nodes | set(self.removed_nodes.keys()))

            # For any nodes that are older than the length of the edge_removal_filter, we can safely clear their data.
            self.update_removed_nodes()

            # Any remaining removed nodes that aren't yet old enough for data clearing simply have NaNs appended so
            # that when it's passed through the LSTMs, the hidden state keeps propagating but the input plays no role
            # (the NaNs get converted to zeros later on).
            for node in self.removed_nodes:
                self.node_data[node].append(np.full((1, self.node_data[node].shape[1]), np.nan))

            # Push the buffered history into each node object.
            for node in self.node_data:
                node.overwrite_data(self.node_data[node], None,
                                    forward_in_time_on_next_overwrite=(self.node_data[node].shape[0]
                                                                       == self.RING_CAPACITY))

            # Only the first two state columns are handed to the scene-graph builder.
            temp_scene_dict = {k: v[:, 0:2] for k, v in self.node_data.items()}
            if not temp_scene_dict:
                new_scene_graph = SceneGraph(edge_radius=self.env.attention_radius)
            else:
                new_scene_graph = TemporalSceneGraph.create_from_temp_scene_dict(
                    temp_scene_dict,
                    self.env.attention_radius,
                    duration=self.RING_CAPACITY,
                    edge_addition_filter=self.hyperparams['edge_addition_filter'],
                    edge_removal_filter=self.hyperparams['edge_removal_filter'],
                    online=True).to_scene_graph(t=self.RING_CAPACITY - 1)

            if self.hyperparams['dynamic_edges'] == 'yes':
                new_nodes, removed_nodes, new_neighbors, removed_neighbors = new_scene_graph - self.scene_graph

                # Aside from updating the scene graph, this for loop updates the graph model
                # structure of all affected nodes.
                not_removed_nodes = [node for node in self.nodes if node not in removed_nodes]
                self.scene_graph = new_scene_graph
                for node in not_removed_nodes:
                    self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors)

                # These next 2 for loops add or remove entire node models.
                for node in new_nodes:
                    if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
                        # Only deal with Models for NodeTypes we want to predict
                        continue

                    self._add_node_model(node)
                    self.node_models_dict[node].update_graph(new_scene_graph, new_neighbors, removed_neighbors)

                for node in removed_nodes:
                    if (node.is_robot and self.hyperparams['incl_robot_node']) or node.type not in self.pred_state.keys():
                        continue

                    self._remove_node_model(node)

            # This actually updates the node models with the newly observed data.
            if run_models:
                inputs = dict()
                inputs_st = dict()
                inputs_np = dict()

                # Besides the modelled nodes, inputs are also prepared for observed nodes
                # whose type is not predicted and for the robot (if the scene has one).
                iter_list = list(self.node_models_dict.keys()) + [node for node in new_inputs_dict
                                                                  if node.type not in self.pred_state.keys()]
                if self.env.scenes[0].robot is not None:
                    iter_list.append(self.env.scenes[0].robot)

                for node in iter_list:
                    input_np = node.get(np.array([node.last_timestep, node.last_timestep]), self.state[node.type])

                    # NOTE(review): `std` is fetched and modified here but never passed to
                    # `standardize` below - confirm whether `std=std` was intended.
                    _, std = self.env.get_standardize_params(self.state[node.type.name], node.type)
                    std[0:2] = self.env.attention_radius[(node.type, node.type)]

                    # The first two state columns are used as the mean, the rest stay zero.
                    rel_state = np.zeros_like(input_np)
                    rel_state[:, 0:2] = input_np[:, 0:2]
                    input_st = self.env.standardize(input_np,
                                                    self.state[node.type.name],
                                                    node.type,
                                                    mean=rel_state)
                    self.rel_states[node] = rel_state

                    # Converting NaNs to zeros.
                    input_np[np.isnan(input_np)] = 0
                    input_st[np.isnan(input_st)] = 0

                    # Convert to torch tensors
                    inputs[node] = torch.tensor(input_np, dtype=torch.float, device=self.device)
                    inputs_st[node] = torch.tensor(input_st, dtype=torch.float, device=self.device)
                    inputs_np[node] = input_np

                # We want tensors of shape (1, ph + 1, state_dim) where the first 1 is the batch size.
                if (self.hyperparams['incl_robot_node']
                        and self.env.scenes[0].robot is not None
                        and robot_present_and_future is not None):
                    if len(robot_present_and_future.shape) == 2:
                        robot_present_and_future = robot_present_and_future[np.newaxis, :]

                    assert robot_present_and_future.shape[1] == prediction_horizon + 1
                    robot_present_and_future = torch.tensor(robot_present_and_future,
                                                            dtype=torch.float, device=self.device)

                for node in self.node_models_dict:
                    self.node_models_dict[node].encoder_forward(inputs,
                                                                inputs_st,
                                                                inputs_np,
                                                                robot_present_and_future,
                                                                maps)

        # If num_predicted_timesteps or num_samples == 0 then do not run the decoder at all,
        # just update the encoder LSTMs.
        if prediction_horizon == 0 or num_samples == 0:
            return

        return self.sample_model(prediction_horizon,
                                 num_samples,
                                 robot_present_and_future=robot_present_and_future,
                                 z_mode=z_mode,
                                 gmm_mode=gmm_mode,
                                 full_dist=full_dist,
                                 all_z_sep=all_z_sep)

    def _run_decoder(self, node,
                     num_predicted_timesteps,
                     num_samples,
                     robot_present_and_future=None,
                     z_mode=False,
                     gmm_mode=False,
                     full_dist=False,
                     all_z_sep=False):
        """Run the decoder of ``node``'s model.

        Returns:
            Tuple ``(prediction_dist, predictions)`` where ``predictions`` is
            the decoder output as a NumPy array with its first two axes swapped.
        """
        model = self.node_models_dict[node]
        prediction_dist, predictions_uns = model.decoder_forward(num_predicted_timesteps,
                                                                 num_samples,
                                                                 robot_present_and_future=robot_present_and_future,
                                                                 z_mode=z_mode,
                                                                 gmm_mode=gmm_mode,
                                                                 full_dist=full_dist,
                                                                 all_z_sep=all_z_sep)

        predictions_np = predictions_uns.cpu().detach().numpy()

        # Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2)
        return prediction_dist, np.transpose(predictions_np, (1, 0, 2, 3))

    def sample_model(self, num_predicted_timesteps,
                     num_samples,
                     robot_present_and_future=None,
                     z_mode=False,
                     gmm_mode=False,
                     full_dist=False,
                     all_z_sep=False):
        """Decode predictions for every registered, non-removed, non-robot node.

        Just start from the encoder output (minus the robot future) and get
        ``num_samples`` of ``num_predicted_timesteps``-length trajectories.

        Args:
            num_predicted_timesteps: Number of timesteps to predict.
            num_samples: Number of samples to draw.
            robot_present_and_future: Optional robot trajectory as a NumPy
                array or torch tensor. A 2-D input gets a leading batch axis;
                axis 1 must have length ``num_predicted_timesteps + 1``.
            z_mode, gmm_mode, full_dist, all_z_sep: Forwarded to the decoder.

        Returns:
            ``None`` when ``num_predicted_timesteps`` or ``num_samples`` is 0,
            otherwise ``(prediction_dists, predictions_dict)``, both keyed by node.
        """
        if num_predicted_timesteps == 0 or num_samples == 0:
            return

        # We want tensors of shape (1, ph + 1, state_dim) where the first 1 is the batch size.
        if (self.hyperparams['incl_robot_node']
                and self.env.scenes[0].robot is not None
                and robot_present_and_future is not None):
            if len(robot_present_and_future.shape) == 2:
                robot_present_and_future = robot_present_and_future[np.newaxis, :]

            assert robot_present_and_future.shape[1] == num_predicted_timesteps + 1

            # incremental_forward hands the decoder a float tensor on self.device;
            # do the same when this method is called directly with an array.
            if not torch.is_tensor(robot_present_and_future):
                robot_present_and_future = torch.tensor(robot_present_and_future,
                                                        dtype=torch.float, device=self.device)

        # No grad since this method is only ever used for prediction.
        with torch.no_grad():
            predictions_dict = dict()
            prediction_dists = dict()
            for node in set(self.nodes) - set(self.removed_nodes.keys()):
                if node.is_robot:
                    continue

                prediction_dists[node], predictions_dict[node] = self._run_decoder(node, num_predicted_timesteps,
                                                                                   num_samples,
                                                                                   robot_present_and_future,
                                                                                   z_mode,
                                                                                   gmm_mode,
                                                                                   full_dist,
                                                                                   all_z_sep)

        return prediction_dists, predictions_dict

    def forward(self, init_env,
                init_timestep,
                input_dicts,  # After the initial environment
                num_predicted_timesteps,
                num_samples,
                robot_present_and_future=None,
                z_mode=False,
                gmm_mode=False,
                full_dist=False,
                all_z_sep=False):
        """Reset to ``init_env``, replay ``input_dicts`` and predict.

        This is the standard forward prediction function, if you have some
        historical data and just want to predict forward some number of
        timesteps.

        Args:
            init_env: Environment passed to ``set_environment``.
            init_timestep: Initial timestep passed to ``set_environment``.
            input_dicts: Sequence of per-timestep input dicts that follow the
                initial environment, applied in order.
            num_predicted_timesteps: Number of timesteps to predict.
            num_samples: Number of samples to draw.
            robot_present_and_future, z_mode, gmm_mode, full_dist, all_z_sep:
                Forwarded to ``sample_model``.

        Returns:
            The result of ``sample_model``.
        """
        # Setting us back to the initial scene graph we had.
        self.set_environment(init_env, init_timestep)

        # Looping through and applying updates to the model. `maps` is a required
        # argument of incremental_forward; this method has no map input to offer,
        # so None is passed explicitly (as set_environment does).
        for input_dict in input_dicts:
            self.incremental_forward(input_dict, maps=None)

        return self.sample_model(num_predicted_timesteps,
                                 num_samples,
                                 robot_present_and_future=robot_present_and_future,
                                 z_mode=z_mode,
                                 gmm_mode=gmm_mode,
                                 full_dist=full_dist,
                                 all_z_sep=all_z_sep)