# Trajectron-plus-plus/trajectron/model/online/online_mgcvae.py
#
# 431 lines
# 22 KiB
# Python

import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import defaultdict, Counter
from trajectron.model.components import *
from trajectron.model.model_utils import *
from trajectron.model.dataset import get_relative_robot_traj
import trajectron.model.dynamics as dynamic_module
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
from trajectron.environment.scene_graph import DirectedEdge
from trajectron.environment.node_type import NodeType
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
    """
    Single-node CVAE that is driven one timestep at a time.

    Recurrent encoder states are carried between calls in ``self.curr_hidden_states``
    instead of re-encoding a whole history window. Per-edge-type encoders are added and
    removed as ``update_graph`` reports neighbors appearing or disappearing. The
    environment handed to the constructor must contain exactly one scene.
    """

    def __init__(self,
                 env,
                 node,
                 model_registrar,
                 hyperparams,
                 device):
        """
        :param env: Environment with exactly one scene; that scene's ``robot`` (may be None)
            and ``dt`` are read.
        :param node: The node this model encodes and predicts for.
        :param model_registrar: Stored as ``self.model_registrar`` and passed to the dynamics model.
        :param hyperparams: Hyperparameter dict; the ``state``, ``pred_state``,
            ``incl_robot_node`` and ``dynamic`` entries are read here.
        :param device: Torch device used for modules and tensors.
        :raises ValueError: If ``env.scenes`` does not contain exactly one scene.
        """
        def total_dims(dims_by_entity):
            # Number of scalar state dimensions described by an {entity: [dim, ...]} mapping.
            return int(np.sum([len(dims) for dims in dims_by_entity.values()]))

        self.hyperparams = hyperparams

        self.node = node
        self.node_type = node.type

        if len(env.scenes) != 1:
            raise ValueError("Passed in Environment has number of scenes != 1")

        self.robot = env.scenes[0].robot
        self.model_registrar = model_registrar
        self.device = device

        self.node_modules = dict()
        self.env = env

        # Set on the first update_graph() call.
        self.scene_graph = None

        self.state = hyperparams['state']
        self.pred_state = hyperparams['pred_state'][self.node_type]
        self.state_length = total_dims(self.state[self.node_type])
        if hyperparams['incl_robot_node']:
            self.robot_state_length = total_dims(self.state[self.robot.type])
        self.pred_state_length = total_dims(self.pred_state)

        # Recurrent states of the history/edge encoders, carried across online timesteps.
        self.curr_hidden_states = dict()
        # Edge-type string -> number of currently connected neighbors of that type.
        self.edge_types = Counter()

        self.create_graphical_model()

        dynamic_params = hyperparams['dynamic'][self.node_type]
        dynamic_class = getattr(dynamic_module, dynamic_params['name'])
        dyn_limits = dynamic_params['limits']
        # NOTE(review): self.x_size is not assigned in this class; presumably it is set by
        # create_node_models() in the parent class -- confirm.
        self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device,
                                     self.model_registrar, self.x_size, self.node_type)
def create_graphical_model(self):
"""
Creates or queries all trainable components.
:return: None
"""
self.clear_submodules()
############################
# Everything but Edges #
############################
self.create_node_models()
for name, module in self.node_modules.items():
module.to(self.device)
def update_graph(self, new_scene_graph, new_neighbors, removed_neighbors):
self.scene_graph = new_scene_graph
if self.node in new_neighbors:
for edge_type, new_neighbor_nodes in new_neighbors[self.node].items():
self.add_edge_model(edge_type)
self.edge_types += Counter({edge_type: len(new_neighbor_nodes)})
if self.node in removed_neighbors:
for edge_type, removed_neighbor_nodes in removed_neighbors[self.node].items():
self.remove_edge_model(edge_type)
self.edge_types -= Counter({edge_type: len(removed_neighbor_nodes)})
def get_edge_to(self, other_node):
return DirectedEdge(self.node, other_node)
def add_edge_model(self, edge_type):
if self.hyperparams['edge_encoding']:
if edge_type + '/edge_encoder' not in self.node_modules:
neighbor_state_length = int(
np.sum([len(entity_dims) for entity_dims in
self.state[self._get_other_node_type_from_edge(edge_type)].values()]))
if self.hyperparams['edge_state_combine_method'] == 'pointnet':
self.add_submodule(edge_type + '/pointnet_encoder',
model_if_absent=nn.Sequential(
nn.Linear(self.state_length, 2 * self.state_length),
nn.ReLU(),
nn.Linear(2 * self.state_length, 2 * self.state_length),
nn.ReLU()))
edge_encoder_input_size = 2 * self.state_length + self.state_length
elif self.hyperparams['edge_state_combine_method'] == 'attention':
self.add_submodule(self.node.type + '/edge_attention_combine',
model_if_absent=TemporallyBatchedAdditiveAttention(
encoder_hidden_state_dim=self.state_length,
decoder_hidden_state_dim=self.state_length))
edge_encoder_input_size = self.state_length + neighbor_state_length
else:
edge_encoder_input_size = self.state_length + neighbor_state_length
self.add_submodule(edge_type + '/edge_encoder',
model_if_absent=nn.LSTM(input_size=edge_encoder_input_size,
hidden_size=self.hyperparams['enc_rnn_dim_edge'],
batch_first=True))
def _get_other_node_type_from_edge(self, edge_type_str):
n2_type_str = edge_type_str.split('->')[1]
return NodeType(n2_type_str, self.env.node_type_list.index(n2_type_str) + 1)
def _get_edge_type_from_str(self, edge_type_str):
n1_type_str, n2_type_str = edge_type_str.split('->')
return (NodeType(n1_type_str, self.env.node_type_list.index(n1_type_str) + 1),
NodeType(n2_type_str, self.env.node_type_list.index(n2_type_str) + 1))
def remove_edge_model(self, edge_type):
if self.hyperparams['edge_encoding']:
if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0:
del self.node_modules[edge_type + '/edge_encoder']
def obtain_encoded_tensors(self,
mode,
inputs,
inputs_st,
inputs_np,
robot_present_and_future,
maps):
x, x_r_t, y_r = None, None, None
batch_size = 1
our_inputs = inputs[self.node]
our_inputs_st = inputs_st[self.node]
initial_dynamics = dict()
initial_dynamics['pos'] = our_inputs[:, 0:2] # TODO: Generalize
initial_dynamics['vel'] = our_inputs[:, 2:4] # TODO: Generalize
self.dynamic.set_initial_condition(initial_dynamics)
#########################################
# Provide basic information to encoders #
#########################################
if self.hyperparams['incl_robot_node'] and self.robot is not None:
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
our_inputs, robot_present_and_future,
self.node.type, self.robot.type)
x_r_t = robot_present_and_future_st[..., 0, :]
y_r = robot_present_and_future_st[..., 1:, :]
##################
# Encode History #
##################
node_history_encoded = self.encode_node_history(our_inputs_st)
##############################
# Encode Node Edges per Type #
##############################
total_edge_influence = None
if self.hyperparams['edge_encoding']:
node_edges_encoded = list()
for edge_type in self.edge_types:
connected_nodes_batched = list()
edge_masks_batched = list()
# We get all nodes which are connected to the current node for the current timestep
connected_nodes_batched.append(self.scene_graph.get_neighbors(self.node,
self._get_other_node_type_from_edge(
edge_type)))
if self.hyperparams['dynamic_edges'] == 'yes':
# We get the edge masks for the current node at the current timestep
edge_masks_for_node = self.scene_graph.get_edge_scaling(self.node)
edge_masks_batched.append(torch.tensor(edge_masks_for_node, dtype=torch.float, device=self.device))
# Encode edges for given edge type
encoded_edges_type = self.encode_edge(inputs,
inputs_st,
inputs_np,
edge_type,
connected_nodes_batched,
edge_masks_batched)
node_edges_encoded.append(encoded_edges_type) # List of [bs/nbs, enc_rnn_dim]
#####################
# Encode Node Edges #
#####################
total_edge_influence = self.encode_total_edge_influence(mode,
node_edges_encoded,
node_history_encoded,
batch_size)
self.TD = {'node_history_encoded': node_history_encoded,
'total_edge_influence': total_edge_influence}
################
# Map Encoding #
################
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
if self.node not in maps:
# This means the node was removed (it is only being kept around because of the edge removal filter).
me_params = self.hyperparams['map_encoder'][self.node_type]
self.TD['encoded_map'] = torch.zeros((1, me_params['output_size']))
else:
encoded_map = self.node_modules[self.node_type + '/map_encoder'](maps[self.node] * 2. - 1.,
(mode == ModeKeys.TRAIN))
do = self.hyperparams['map_encoder'][self.node_type]['dropout']
encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN))
self.TD['encoded_map'] = encoded_map
######################################
# Concatenate Encoder Outputs into x #
######################################
return self.create_encoder_rep(mode, self.TD, x_r_t, y_r)
def create_encoder_rep(self, mode,
TD,
robot_present_st,
robot_future_st):
# Unpacking TD
node_history_encoded = TD['node_history_encoded']
if self.hyperparams['edge_encoding']:
total_edge_influence = TD['total_edge_influence']
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
encoded_map = TD['encoded_map']
if (self.hyperparams['incl_robot_node']
and self.robot is not None
and robot_future_st is not None
and robot_present_st is not None):
robot_future_encoder = self.encode_robot_future(mode, robot_present_st, robot_future_st)
# Tiling for multiple samples
# This tiling is done because:
# a) we must consider the prediction case where there are many candidate robot future actions,
# b) the edge and history encoders are all the same regardless of which candidate future robot action
# we're evaluating.
node_history_encoded = TD['node_history_encoded'].repeat(robot_future_st.size()[0], 1)
if self.hyperparams['edge_encoding']:
total_edge_influence = TD['total_edge_influence'].repeat(robot_future_st.size()[0], 1)
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
encoded_map = TD['encoded_map'].repeat(robot_future_st.size()[0], 1)
elif self.hyperparams['incl_robot_node'] and self.robot is not None:
# Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
robot_future_encoder = torch.zeros([1, 4 * self.hyperparams['enc_rnn_dim_future']], device=self.device)
x_concat_list = list()
# Every node has an edge-influence encoder (which could just be zero).
if self.hyperparams['edge_encoding']:
x_concat_list.append(total_edge_influence) # [bs/nbs, 4*enc_rnn_dim]
# Every node has a history encoder.
x_concat_list.append(node_history_encoded) # [bs/nbs, enc_rnn_dim_history]
if self.hyperparams['incl_robot_node'] and self.robot is not None:
x_concat_list.append(robot_future_encoder) # [bs/nbs, 4*enc_rnn_dim_history]
if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
x_concat_list.append(encoded_map) # [bs/nbs, CNN output size]
return torch.cat(x_concat_list, dim=1)
def encode_node_history(self, inputs_st):
new_state = torch.unsqueeze(inputs_st, dim=1) # [bs, 1, state_dim]
if self.node.type + '/node_history_encoder' not in self.curr_hidden_states:
outputs, self.curr_hidden_states[self.node.type + '/node_history_encoder'] = self.node_modules[
self.node.type + '/node_history_encoder'](new_state)
else:
outputs, self.curr_hidden_states[self.node.type + '/node_history_encoder'] = self.node_modules[
self.node.type + '/node_history_encoder'](new_state, self.curr_hidden_states[
self.node.type + '/node_history_encoder'])
return outputs[:, 0, :]
def encode_edge(self, inputs, inputs_st, inputs_np, edge_type, connected_nodes, edge_masks):
edge_type_tuple = self._get_edge_type_from_str(edge_type)
edge_states_list = list() # list of [#of neighbors, max_ht, state_dim]
neighbor_states = list()
orig_rel_state = inputs[self.node].cpu().numpy()
for node in connected_nodes[0]:
neighbor_state_np = inputs_np[node]
# Make State relative to node
_, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type)
std[0:2] = self.env.attention_radius[edge_type_tuple]
# TODO: This all makes the unsafe assumption that the first n dims
# refer to the same quantities even for different agent types!
equal_dims = np.min((neighbor_state_np.shape[-1], orig_rel_state.shape[-1]))
rel_state = np.zeros_like(neighbor_state_np)
rel_state[..., :equal_dims] = orig_rel_state[..., :equal_dims]
neighbor_state_np_st = self.env.standardize(neighbor_state_np,
self.state[node.type],
node_type=node.type,
mean=rel_state,
std=std)
neighbor_state = torch.tensor(neighbor_state_np_st).float().to(self.device)
neighbor_states.append(neighbor_state)
if len(neighbor_states) == 0: # There are no neighbors for edge type # TODO necessary?
neighbor_state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[edge_type[1]].values()]))
edge_states_list.append(torch.zeros((1, 1, neighbor_state_length), device=self.device))
else:
edge_states_list.append(torch.stack(neighbor_states, dim=0))
if self.hyperparams['edge_state_combine_method'] == 'sum':
# Used in Structural-RNN to combine edges as well.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_mask, dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'max':
# Used in NLP, e.g. max over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_mask, dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'mean':
# Used in NLP, e.g. mean over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_mask, dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
joint_history = torch.cat([combined_neighbors, torch.unsqueeze(inputs_st[self.node], dim=0)], dim=-1)
if edge_type + '/edge_encoder' not in self.curr_hidden_states:
outputs, self.curr_hidden_states[edge_type + '/edge_encoder'] = self.node_modules[
edge_type + '/edge_encoder'](joint_history)
else:
outputs, self.curr_hidden_states[edge_type + '/edge_encoder'] = self.node_modules[
edge_type + '/edge_encoder'](joint_history, self.curr_hidden_states[edge_type + '/edge_encoder'])
if self.hyperparams['dynamic_edges'] == 'yes':
return outputs[:, 0, :] * combined_edge_masks
else:
return outputs[:, 0, :] # [bs, enc_rnn_dim]
def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None, maps=None):
# Always predicting with the online model.
mode = ModeKeys.PREDICT
self.x = self.obtain_encoded_tensors(mode,
inputs,
inputs_st,
inputs_np,
robot_present_and_future,
maps)
self.n_s_t0 = inputs_st[self.node]
self.latent.p_dist = self.p_z_x(mode, self.x)
# robot_future_st is optional here since you can use the same one from encoder_forward,
# but if it's given then we'll re-run that part of the model (if the node is adjacent to the robot).
def decoder_forward(self, prediction_horizon,
num_samples,
robot_present_and_future=None,
z_mode=False,
gmm_mode=False,
full_dist=False,
all_z_sep=False):
# Always predicting with the online model.
mode = ModeKeys.PREDICT
x_nr_t, y_r = None, None
if (self.hyperparams['incl_robot_node']
and self.robot is not None
and robot_present_and_future is not None):
our_inputs = torch.tensor(self.node.get(np.array([self.node.last_timestep]),
self.state[self.node.type],
padding=0.0),
dtype=torch.float,
device=self.device)
robot_present_and_future_st = get_relative_robot_traj(self.env, self.state,
our_inputs, robot_present_and_future,
self.node.type, self.robot.type)
x_nr_t = robot_present_and_future_st[..., 0, :]
y_r = robot_present_and_future_st[..., 1:, :]
self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r)
self.latent.p_dist = self.p_z_x(mode, self.x)
# Making sure n_s_t0 has the same batch size as x_nr_t
self.n_s_t0 = self.n_s_t0[[0]].repeat(x_nr_t.size()[0], 1)
z, num_samples, num_components = self.latent.sample_p(num_samples,
mode,
most_likely_z=z_mode,
full_dist=full_dist,
all_z_sep=all_z_sep)
y_dist, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, self.n_s_t0, z,
prediction_horizon,
num_samples,
num_components,
gmm_mode)
return y_dist, our_sampled_future