# Trajectron-plus-plus/trajectron/model/mgcvae.py
# (Source listing metadata: 1144 lines, 57 KiB, Python.)

import warnings
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from trajectron.model.components import *
from trajectron.model.model_utils import *
import trajectron.model.dynamics as dynamic_module
from trajectron.environment.scene_graph import DirectedEdge
class MultimodalGenerativeCVAE(object):
def __init__(self,
             env,
             node_type,
             model_registrar,
             hyperparams,
             device,
             edge_types,
             log_writer=None):
    """
    Multimodal generative CVAE for a single node type.

    :param env: Environment. ``env.scenes[0].dt`` is used as the dynamics timestep, and
        ``env.robot_type`` selects the robot state layout when ``incl_robot_node`` is set.
    :param node_type: Node type this model predicts for.
    :param model_registrar: Registrar through which all submodules are created or looked up.
    :param hyperparams: Hyperparameter dict.
    :param device: Torch device the submodules are moved to.
    :param edge_types: Iterable of edge-type tuples. Only those whose first element equals
        ``node_type`` are kept.
    :param log_writer: Optional summary writer used for logging scalars and images.
    """
    self.hyperparams = hyperparams
    self.env = env
    self.node_type = node_type
    self.model_registrar = model_registrar
    self.log_writer = log_writer
    self.device = device
    # Compare by equality, not identity: node types that are equal but are distinct objects
    # must still select their edges.
    self.edge_types = [edge_type for edge_type in edge_types if edge_type[0] == node_type]
    self.curr_iter = 0

    self.node_modules = dict()

    self.min_hl = self.hyperparams['minimum_history_length']
    self.max_hl = self.hyperparams['maximum_history_length']
    self.ph = self.hyperparams['prediction_horizon']
    self.state = self.hyperparams['state']
    self.pred_state = self.hyperparams['pred_state'][node_type]
    # State widths are the number of dimensions summed over all state entities.
    self.state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[node_type].values()]))
    if self.hyperparams['incl_robot_node']:
        self.robot_state_length = int(
            np.sum([len(entity_dims) for entity_dims in self.state[env.robot_type].values()])
        )
    self.pred_state_length = int(np.sum([len(entity_dims) for entity_dims in self.pred_state.values()]))

    edge_types_str = [DirectedEdge.get_str_from_types(*edge_type) for edge_type in self.edge_types]
    self.create_graphical_model(edge_types_str)

    dynamic_class = getattr(dynamic_module, hyperparams['dynamic'][self.node_type]['name'])
    dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
    # self.x_size is set by create_node_models(), which create_graphical_model() runs above.
    self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device,
                                 self.model_registrar, self.x_size, self.node_type)
def set_curr_iter(self, curr_iter):
self.curr_iter = curr_iter
def add_submodule(self, name, model_if_absent):
    """
    Look ``name`` up in the model registrar and keep the result in ``self.node_modules[name]``.

    ``model_if_absent`` is handed to the registrar. Presumably it is used when nothing is
    registered under ``name`` yet; confirm against ``model_registrar.get_model``.
    """
    module = self.model_registrar.get_model(name, model_if_absent)
    self.node_modules[name] = module
def clear_submodules(self):
self.node_modules.clear()
def create_node_models(self):
    """
    Create (or fetch from the registrar) every submodule of this node type except the edge encoders.

    Besides registering modules in ``self.node_modules``, this sets ``self.latent``,
    ``self.x_size`` (width of the encoder output ``x``) and ``self.z_size`` (N * K).
    When edge encoding is enabled it also sets ``self.eie_output_dims`` (width of the combined
    edge influence).

    :return: None
    :raises ValueError: If edge encoding is enabled and ``edge_influence_combine_method`` is not
        one of 'bi-rnn', 'attention', 'sum', 'mean' or 'max'.
    """
    ############################
    #   Node History Encoder   #
    ############################
    self.add_submodule(self.node_type + '/node_history_encoder',
                       model_if_absent=nn.LSTM(input_size=self.state_length,
                                               hidden_size=self.hyperparams['enc_rnn_dim_history'],
                                               batch_first=True))

    ###########################
    #   Node Future Encoder   #
    ###########################
    # We'll create this here, but then later check if in training mode.
    # Based on that, we'll factor this into the computation graph (or not).
    self.add_submodule(self.node_type + '/node_future_encoder',
                       model_if_absent=nn.LSTM(input_size=self.pred_state_length,
                                               hidden_size=self.hyperparams['enc_rnn_dim_future'],
                                               bidirectional=True,
                                               batch_first=True))
    # These are related to how you initialize states for the node future encoder.
    self.add_submodule(self.node_type + '/node_future_encoder/initial_h',
                       model_if_absent=nn.Linear(self.state_length,
                                                 self.hyperparams['enc_rnn_dim_future']))
    self.add_submodule(self.node_type + '/node_future_encoder/initial_c',
                       model_if_absent=nn.Linear(self.state_length,
                                                 self.hyperparams['enc_rnn_dim_future']))

    ############################
    #   Robot Future Encoder   #
    ############################
    # We'll create this here, but then later check if we're next to the robot.
    # Based on that, we'll factor this into the computation graph (or not).
    if self.hyperparams['incl_robot_node']:
        self.add_submodule('robot_future_encoder',
                           model_if_absent=nn.LSTM(input_size=self.robot_state_length,
                                                   hidden_size=self.hyperparams['enc_rnn_dim_future'],
                                                   bidirectional=True,
                                                   batch_first=True))
        # These are related to how you initialize states for the robot future encoder.
        self.add_submodule('robot_future_encoder/initial_h',
                           model_if_absent=nn.Linear(self.robot_state_length,
                                                     self.hyperparams['enc_rnn_dim_future']))
        self.add_submodule('robot_future_encoder/initial_c',
                           model_if_absent=nn.Linear(self.robot_state_length,
                                                     self.hyperparams['enc_rnn_dim_future']))

    if self.hyperparams['edge_encoding']:
        ##############################
        #   Edge Influence Encoder   #
        ##############################
        # NOTE: The edge influence encoding happens during calls
        # to forward or incremental_forward, so we don't create
        # a model for it here for the sum, mean and max variants.
        combine_method = self.hyperparams['edge_influence_combine_method']
        if combine_method == 'bi-rnn':
            self.add_submodule(self.node_type + '/edge_influence_encoder',
                               model_if_absent=nn.LSTM(input_size=self.hyperparams['enc_rnn_dim_edge'],
                                                       hidden_size=self.hyperparams['enc_rnn_dim_edge_influence'],
                                                       bidirectional=True,
                                                       batch_first=True))

            # Four times because we're trying to mimic a bi-directional
            # LSTM's output (which, here, is c and h from both ends).
            self.eie_output_dims = 4 * self.hyperparams['enc_rnn_dim_edge_influence']

        elif combine_method == 'attention':
            # Chose additive attention because of https://arxiv.org/pdf/1703.03906.pdf
            # We calculate an attention context vector using the encoded edges as the "encoder"
            # (that we attend _over_)
            # and the node history encoder representation as the "decoder state" (that we attend _on_).
            # NOTE(review): the attended edge encodings have width enc_rnn_dim_edge, so this appears
            # to assume enc_rnn_dim_edge == enc_rnn_dim_edge_influence; confirm.
            self.add_submodule(self.node_type + '/edge_influence_encoder',
                               model_if_absent=AdditiveAttention(
                                   encoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_edge_influence'],
                                   decoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_history']))

            self.eie_output_dims = self.hyperparams['enc_rnn_dim_edge_influence']

        elif combine_method in ('sum', 'mean', 'max'):
            # Element-wise reductions over the per-edge-type encodings keep the
            # edge encoder's hidden size (enc_rnn_dim_edge).
            self.eie_output_dims = self.hyperparams['enc_rnn_dim_edge']

        else:
            raise ValueError("Unsupported edge_influence_combine_method: %r" % (combine_method,))

    ###################
    #   Map Encoder   #
    ###################
    if self.hyperparams['use_map_encoding']:
        if self.node_type in self.hyperparams['map_encoder']:
            me_params = self.hyperparams['map_encoder'][self.node_type]
            self.add_submodule(self.node_type + '/map_encoder',
                               model_if_absent=CNNMapEncoder(me_params['map_channels'],
                                                             me_params['hidden_channels'],
                                                             me_params['output_size'],
                                                             me_params['masks'],
                                                             me_params['strides'],
                                                             me_params['patch_size']))

    ################################
    #   Discrete Latent Variable   #
    ################################
    self.latent = DiscreteLatent(self.hyperparams, self.device)

    ######################################################################
    #   Various Fully-Connected Layers from Encoder to Latent Variable   #
    ######################################################################
    # Node History Encoder
    x_size = self.hyperparams['enc_rnn_dim_history']
    if self.hyperparams['edge_encoding']:
        #              Edge Encoder
        x_size += self.eie_output_dims
    if self.hyperparams['incl_robot_node']:
        #              Future Conditional Encoder
        x_size += 4 * self.hyperparams['enc_rnn_dim_future']
    if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
        #              Map Encoder
        x_size += self.hyperparams['map_encoder'][self.node_type]['output_size']

    z_size = self.hyperparams['N'] * self.hyperparams['K']

    if self.hyperparams['p_z_x_MLP_dims'] is not None:
        self.add_submodule(self.node_type + '/p_z_x',
                           model_if_absent=nn.Linear(x_size, self.hyperparams['p_z_x_MLP_dims']))
        hx_size = self.hyperparams['p_z_x_MLP_dims']
    else:
        hx_size = x_size

    self.add_submodule(self.node_type + '/hx_to_z',
                       model_if_absent=nn.Linear(hx_size, self.latent.z_dim))

    if self.hyperparams['q_z_xy_MLP_dims'] is not None:
        self.add_submodule(self.node_type + '/q_z_xy',
                           #                                           Node Future Encoder
                           model_if_absent=nn.Linear(x_size + 4 * self.hyperparams['enc_rnn_dim_future'],
                                                     self.hyperparams['q_z_xy_MLP_dims']))
        hxy_size = self.hyperparams['q_z_xy_MLP_dims']
    else:
        #                           Node Future Encoder
        hxy_size = x_size + 4 * self.hyperparams['enc_rnn_dim_future']

    self.add_submodule(self.node_type + '/hxy_to_z',
                       model_if_absent=nn.Linear(hxy_size, self.latent.z_dim))

    ####################
    #   Decoder LSTM   #
    ####################
    if self.hyperparams['incl_robot_node']:
        decoder_input_dims = self.pred_state_length + self.robot_state_length + z_size + x_size
    else:
        decoder_input_dims = self.pred_state_length + z_size + x_size

    self.add_submodule(self.node_type + '/decoder/state_action',
                       model_if_absent=nn.Sequential(
                           nn.Linear(self.state_length, self.pred_state_length)))

    self.add_submodule(self.node_type + '/decoder/rnn_cell',
                       model_if_absent=nn.GRUCell(decoder_input_dims, self.hyperparams['dec_rnn_dim']))
    self.add_submodule(self.node_type + '/decoder/initial_h',
                       model_if_absent=nn.Linear(z_size + x_size, self.hyperparams['dec_rnn_dim']))

    ###################
    #   Decoder GMM   #
    ###################
    self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_pis',
                       model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                 self.hyperparams['GMM_components']))
    self.add_submodule(self.node_type + '/decoder/proj_to_GMM_mus',
                       model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                 self.hyperparams['GMM_components'] * self.pred_state_length))
    self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_sigmas',
                       model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                 self.hyperparams['GMM_components'] * self.pred_state_length))
    self.add_submodule(self.node_type + '/decoder/proj_to_GMM_corrs',
                       model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                 self.hyperparams['GMM_components']))

    self.x_size = x_size
    self.z_size = z_size
def create_edge_models(self, edge_types):
    """
    Create (or fetch) the per-edge-type encoder modules.

    :param edge_types: Edge-type strings containing '->'. The part after the first '->' names
        the neighbor type, whose state layout sizes the encoder input.
    :return: None
    """
    for edge_type in edge_types:
        neighbor_type = edge_type.split('->')[1]
        neighbor_state_length = int(sum(len(dims) for dims in self.state[neighbor_type].values()))

        combine_method = self.hyperparams['edge_state_combine_method']
        if combine_method == 'pointnet':
            pointnet = nn.Sequential(
                nn.Linear(self.state_length, 2 * self.state_length),
                nn.ReLU(),
                nn.Linear(2 * self.state_length, 2 * self.state_length),
                nn.ReLU())
            self.add_submodule(edge_type + '/pointnet_encoder', model_if_absent=pointnet)
            # 2 * state_length pooled features plus the node's own state.
            input_size = 3 * self.state_length
        else:
            if combine_method == 'attention':
                attention = TemporallyBatchedAdditiveAttention(
                    encoder_hidden_state_dim=self.state_length,
                    decoder_hidden_state_dim=self.state_length)
                self.add_submodule(self.node_type + '/edge_attention_combine', model_if_absent=attention)
            # Node state concatenated with the combined neighbor state.
            input_size = self.state_length + neighbor_state_length

        edge_encoder = nn.LSTM(input_size=input_size,
                               hidden_size=self.hyperparams['enc_rnn_dim_edge'],
                               batch_first=True)
        self.add_submodule(edge_type + '/edge_encoder', model_if_absent=edge_encoder)
def create_graphical_model(self, edge_types):
    """
    Rebuild ``self.node_modules``: node models always, edge encoders only if edge encoding is on.
    Every collected module is then moved to ``self.device``.

    :param edge_types: List containing strings for all possible edge types for the node type.
    :return: None
    """
    self.clear_submodules()

    # Everything but edges.
    self.create_node_models()

    # Edge encoders.
    if self.hyperparams['edge_encoding']:
        self.create_edge_models(edge_types)

    for module in self.node_modules.values():
        module.to(self.device)
def create_new_scheduler(self, name, annealer, annealer_kws, creation_condition=True):
    """
    Create an annealed value driven by a dummy optimizer and an LR scheduler.

    ``<name>_scheduler`` is always set, to None when ``creation_condition`` is False. Otherwise
    this also sets ``<name>_annealer``, ``<name>`` (initialised to ``annealer(0)``) and
    ``<name>_optimizer``, and appends to ``self.schedulers`` and ``self.annealed_vars``.

    :param name: (Possibly dotted) attribute path on ``self``, e.g. 'latent.temp'.
    :param annealer: Factory called with the keyword dict; returns a callable step -> value.
    :param annealer_kws: Keyword dict for ``annealer``. It is not modified: a copy with
        ``'device'`` added is passed to the annealer.
    :param creation_condition: If False, only ``<name>_scheduler`` (= None) is set.
    :return: None
    """
    value_scheduler = None
    rsetattr(self, name + '_scheduler', value_scheduler)
    if creation_condition:
        # Copy so that adding the device does not mutate the caller's dict.
        annealer_kws = dict(annealer_kws)
        annealer_kws['device'] = self.device
        value_annealer = annealer(annealer_kws)
        rsetattr(self, name + '_annealer', value_annealer)

        # This is the value that we'll update on each call of
        # step_annealers().
        rsetattr(self, name, value_annealer(0).clone().detach())
        # The optimizer is never stepped. It only carries the annealed value as its 'lr';
        # step_annealers() reads param_groups[0]['lr'] after stepping the scheduler.
        dummy_optimizer = optim.Optimizer([rgetattr(self, name)], {'lr': value_annealer(0).clone().detach()})
        rsetattr(self, name + '_optimizer', dummy_optimizer)

        value_scheduler = CustomLR(dummy_optimizer,
                                   value_annealer)
        rsetattr(self, name + '_scheduler', value_scheduler)

        self.schedulers.append(value_scheduler)
        self.annealed_vars.append(name)
def set_annealing_params(self):
    """
    (Re)initialise the annealed hyperparameters.

    Resets ``self.schedulers`` and ``self.annealed_vars``, then creates schedulers for:
    the KL weight (sigmoid annealing), the latent temperature (exponential annealing) and,
    only when ``use_z_logit_clipping`` is set, the latent z-logit clip (sigmoid annealing).
    """
    self.schedulers = list()
    self.annealed_vars = list()
    hp = self.hyperparams

    kl_weight_kws = {
        'start': hp['kl_weight_start'],
        'finish': hp['kl_weight'],
        'center_step': hp['kl_crossover'],
        'steps_lo_to_hi': hp['kl_crossover'] / hp['kl_sigmoid_divisor'],
    }
    self.create_new_scheduler(name='kl_weight',
                              annealer=sigmoid_anneal,
                              annealer_kws=kl_weight_kws)

    temperature_kws = {
        'start': hp['tau_init'],
        'finish': hp['tau_final'],
        'rate': hp['tau_decay_rate'],
    }
    self.create_new_scheduler(name='latent.temp',
                              annealer=exp_anneal,
                              annealer_kws=temperature_kws)

    z_logit_clip_kws = {
        'start': hp['z_logit_clip_start'],
        'finish': hp['z_logit_clip_final'],
        'center_step': hp['z_logit_clip_crossover'],
        'steps_lo_to_hi': hp['z_logit_clip_crossover'] / hp['z_logit_clip_divisor'],
    }
    self.create_new_scheduler(name='latent.z_logit_clip',
                              annealer=sigmoid_anneal,
                              annealer_kws=z_logit_clip_kws,
                              creation_condition=hp['use_z_logit_clipping'])
def step_annealers(self):
# This should manage all of the step-wise changed
# parameters automatically.
for idx, annealed_var in enumerate(self.annealed_vars):
if rgetattr(self, annealed_var + '_scheduler') is not None:
# First we step the scheduler.
with warnings.catch_warnings(): # We use a dummy optimizer: Warning because no .step() was called on it
warnings.simplefilter("ignore")
rgetattr(self, annealed_var + '_scheduler').step()
# Then we set the annealed vars' value.
rsetattr(self, annealed_var, rgetattr(self, annealed_var + '_optimizer').param_groups[0]['lr'])
self.summarize_annealers()
def summarize_annealers(self):
if self.log_writer is not None:
for annealed_var in self.annealed_vars:
if rgetattr(self, annealed_var) is not None:
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), annealed_var.replace('.', '/')),
rgetattr(self, annealed_var), self.curr_iter)
def obtain_encoded_tensors(self,
                           mode,
                           inputs,
                           inputs_st,
                           labels,
                           labels_st,
                           first_history_indices,
                           neighbors,
                           neighbors_edge_value,
                           robot,
                           map) -> (torch.Tensor,
                                    torch.Tensor,
                                    torch.Tensor,
                                    torch.Tensor,
                                    torch.Tensor,
                                    torch.Tensor):
    """
    Encodes input and output tensors for node and robot.

    As a side effect, sets the initial condition of ``self.dynamic`` from the last
    (unstandardized) input state. Columns 0:2 of that state are used as position and
    columns 2:4 as velocity.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param inputs: Input tensor including the state for each agent over time [bs, t, state].
    :param inputs_st: Standardized input tensor.
    :param labels: Unused here; kept for interface compatibility.
    :param labels_st: Standardized label tensor [bs, t, pred_state].
    :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
    :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                        [[bs, t, neighbor state]]
    :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
    :param robot: Standardized robot state over time. [bs, t, robot_state]
    :param map: Tensor of Map information. [bs, channels, x, y]
    :return: tuple(x, x_r_t, y_e, y_r, y, n_s_t0)
        WHERE
        - x: Encoded input / condition tensor to the CVAE x_e.
        - x_r_t: Present robot state ``robot[..., 0, :]`` (None unless incl_robot_node).
        - y_e: Encoded label / future of the node (None unless mode is Train or Eval).
        - y_r: Future robot states ``robot[..., 1:, :]`` (None unless incl_robot_node).
        - y: Standardized label / future of the node (None in Predict mode).
        - n_s_t0: Standardized current state of the node.
    """
    x, x_r_t, y_e, y_r, y = None, None, None, None, None
    batch_size = inputs.shape[0]

    #########################################
    # Provide basic information to encoders #
    #########################################
    node_history = inputs
    node_history_st = inputs_st
    n_s_t0 = inputs_st[:, -1]

    # The unstandardized present position/velocity seed the dynamics integrator.
    initial_dynamics = {'pos': inputs[:, -1, 0:2], 'vel': inputs[:, -1, 2:4]}
    self.dynamic.set_initial_condition(initial_dynamics)

    if self.hyperparams['incl_robot_node']:
        x_r_t, y_r = robot[..., 0, :], robot[..., 1:, :]

    ##################
    # Encode History #
    ##################
    node_history_encoded = self.encode_node_history(mode,
                                                    node_history_st,
                                                    first_history_indices)

    ##################
    # Encode Present #
    ##################
    node_present = n_s_t0  # [bs, state_dim]

    ##################
    # Encode Future #
    ##################
    if mode != ModeKeys.PREDICT:
        y = labels_st

    ##############################
    # Encode Node Edges per Type #
    ##############################
    if self.hyperparams['edge_encoding']:
        node_edges_encoded = list()
        for edge_type in self.edge_types:
            # Encode edges for given edge type
            encoded_edges_type = self.encode_edge(mode,
                                                  node_history,
                                                  node_history_st,
                                                  edge_type,
                                                  neighbors[edge_type],
                                                  neighbors_edge_value[edge_type],
                                                  first_history_indices)
            node_edges_encoded.append(encoded_edges_type)  # List of [bs/nbs, enc_rnn_dim]
        #####################
        # Encode Node Edges #
        #####################
        total_edge_influence = self.encode_total_edge_influence(mode,
                                                                node_edges_encoded,
                                                                node_history_encoded,
                                                                batch_size)

    ################
    # Map Encoding #
    ################
    if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
        if self.log_writer and (self.curr_iter + 1) % 500 == 0:
            # Every 500 iterations, log the map crops with a small square (value 1.) painted at
            # a location derived from patch_size (presumably the node's position in the crop).
            map_clone = map.clone()
            map_patch = self.hyperparams['map_encoder'][self.node_type]['patch_size']
            map_clone[:, :, map_patch[1] - 5:map_patch[1] + 5, map_patch[0] - 5:map_patch[0] + 5] = 1.
            self.log_writer.add_images(f"{self.node_type}/cropped_maps", map_clone,
                                       self.curr_iter, dataformats='NCWH')

        # map * 2 - 1 rescales the map (assumed to be in [0, 1]; confirm) to [-1, 1].
        encoded_map = self.node_modules[self.node_type + '/map_encoder'](map * 2. - 1., (mode == ModeKeys.TRAIN))
        do = self.hyperparams['map_encoder'][self.node_type]['dropout']
        encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN))

    ######################################
    # Concatenate Encoder Outputs into x #
    ######################################
    x_concat_list = list()

    # Every node has an edge-influence encoder (which could just be zero).
    if self.hyperparams['edge_encoding']:
        x_concat_list.append(total_edge_influence)  # [bs/nbs, eie_output_dims]

    # Every node has a history encoder.
    x_concat_list.append(node_history_encoded)  # [bs/nbs, enc_rnn_dim_history]

    if self.hyperparams['incl_robot_node']:
        robot_future_encoder = self.encode_robot_future(mode, x_r_t, y_r)
        x_concat_list.append(robot_future_encoder)

    if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
        if self.log_writer:
            self.log_writer.add_scalar(f"{self.node_type}/encoded_map_max",
                                       torch.max(torch.abs(encoded_map)), self.curr_iter)
        x_concat_list.append(encoded_map)

    x = torch.cat(x_concat_list, dim=1)

    # The node future is only encoded when labels are available (training / evaluation).
    if mode == ModeKeys.TRAIN or mode == ModeKeys.EVAL:
        y_e = self.encode_node_future(mode, node_present, y)

    return x, x_r_t, y_e, y_r, y, n_s_t0
def encode_node_history(self, mode, node_hist, first_history_indices):
    """
    Run the node-history LSTM over the (variable-length) history and keep one output per node.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param node_hist: Historic and current state of the node. [bs, mhl, state]
    :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
    :return: Encoded node history tensor. [bs, enc_rnn_dim]
    """
    history_encoder = self.node_modules[self.node_type + '/node_history_encoder']
    outputs, _ = run_lstm_on_variable_length_seqs(history_encoder,
                                                  original_seqs=node_hist,
                                                  lower_indices=first_history_indices)
    keep_prob = self.hyperparams['rnn_kwargs']['dropout_keep_prob']
    outputs = F.dropout(outputs,
                        p=1. - keep_prob,
                        training=(mode == ModeKeys.TRAIN))  # [bs, max_time, enc_rnn_dim]

    # Per node, pick time index -(first_history_index + 1). This presumably addresses the last
    # valid step given how run_lstm_on_variable_length_seqs lays out its output; confirm there.
    batch_rows = torch.arange(first_history_indices.shape[0])
    return outputs[batch_rows, -(first_history_indices + 1)]
def encode_edge(self,
mode,
node_history,
node_history_st,
edge_type,
neighbors,
neighbors_edge_value,
first_history_indices):
max_hl = self.hyperparams['maximum_history_length']
edge_states_list = list() # list of [#of neighbors, max_ht, state_dim]
for i, neighbor_states in enumerate(neighbors): # Get neighbors for timestep in batch
if len(neighbor_states) == 0: # There are no neighbors for edge type # TODO necessary?
neighbor_state_length = int(
np.sum([len(entity_dims) for entity_dims in self.state[edge_type[1]].values()])
)
edge_states_list.append(torch.zeros((1, max_hl + 1, neighbor_state_length), device=self.device))
else:
edge_states_list.append(torch.stack(neighbor_states, dim=0).to(self.device))
if self.hyperparams['edge_state_combine_method'] == 'sum':
# Used in Structural-RNN to combine edges as well.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_value in neighbors_edge_value:
op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_value.to(self.device),
dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'max':
# Used in NLP, e.g. max over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_value in neighbors_edge_value:
op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_value.to(self.device),
dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'mean':
# Used in NLP, e.g. mean over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_value in neighbors_edge_value:
op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_value.to(self.device),
dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
joint_history = torch.cat([combined_neighbors, node_history_st], dim=-1)
outputs, _ = run_lstm_on_variable_length_seqs(
self.node_modules[DirectedEdge.get_str_from_types(*edge_type) + '/edge_encoder'],
original_seqs=joint_history,
lower_indices=first_history_indices
)
outputs = F.dropout(outputs,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN)) # [bs, max_time, enc_rnn_dim]
last_index_per_sequence = -(first_history_indices + 1)
ret = outputs[torch.arange(last_index_per_sequence.shape[0]), last_index_per_sequence]
if self.hyperparams['dynamic_edges'] == 'yes':
return ret * combined_edge_masks
else:
return ret
def encode_total_edge_influence(self, mode, encoded_edges, node_history_encoder, batch_size):
if self.hyperparams['edge_influence_combine_method'] == 'sum':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
combined_edges = torch.sum(stacked_encoded_edges, dim=0)
elif self.hyperparams['edge_influence_combine_method'] == 'mean':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
combined_edges = torch.mean(stacked_encoded_edges, dim=0)
elif self.hyperparams['edge_influence_combine_method'] == 'max':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
combined_edges = torch.max(stacked_encoded_edges, dim=0)
elif self.hyperparams['edge_influence_combine_method'] == 'bi-rnn':
if len(encoded_edges) == 0:
combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)
else:
# axis=1 because then we get size [batch_size, max_time, depth]
encoded_edges = torch.stack(encoded_edges, dim=1)
_, state = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges)
combined_edges = unpack_RNN_state(state)
combined_edges = F.dropout(combined_edges,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
elif self.hyperparams['edge_influence_combine_method'] == 'attention':
# Used in Social Attention (https://arxiv.org/abs/1710.04689)
if len(encoded_edges) == 0:
combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)
else:
# axis=1 because then we get size [batch_size, max_time, depth]
encoded_edges = torch.stack(encoded_edges, dim=1)
combined_edges, _ = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges,
node_history_encoder)
combined_edges = F.dropout(combined_edges,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
return combined_edges
def encode_node_future(self, mode, node_present, node_future) -> torch.Tensor:
    """
    Encode the node's future with the bi-directional node-future LSTM.

    The forward direction's initial (h, c) are linear projections of the present state.
    The backward direction starts from zeros.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param node_present: Current state of the node. [bs, state]
    :param node_future: Future states of the node. [bs, ph, state]
    :return: Encoded future: the unpacked final LSTM state, after dropout.
    """
    initial_state = []
    for init_key in ('/node_future_encoder/initial_h', '/node_future_encoder/initial_c'):
        forward_init = self.node_modules[self.node_type + init_key](node_present)
        backward_init = torch.zeros_like(forward_init, device=self.device)
        initial_state.append(torch.stack([forward_init, backward_init], dim=0))

    future_encoder = self.node_modules[self.node_type + '/node_future_encoder']
    _, final_state = future_encoder(node_future, tuple(initial_state))
    encoded = unpack_RNN_state(final_state)

    keep_prob = self.hyperparams['rnn_kwargs']['dropout_keep_prob']
    return F.dropout(encoded, p=1. - keep_prob, training=(mode == ModeKeys.TRAIN))
def encode_robot_future(self, mode, robot_present, robot_future) -> torch.Tensor:
    """
    Encode the robot's future with the bi-directional robot-future LSTM.

    The module names are not prefixed with the node type. The forward direction's initial (h, c)
    are linear projections of ``robot_present``; the backward direction starts from zeros.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param robot_present: Current state of the robot. [bs, state]
    :param robot_future: Future states of the robot. [bs, ph, state]
    :return: Encoded future: the unpacked final LSTM state, after dropout.
    """
    initial_state = []
    for init_key in ('robot_future_encoder/initial_h', 'robot_future_encoder/initial_c'):
        forward_init = self.node_modules[init_key](robot_present)
        backward_init = torch.zeros_like(forward_init, device=self.device)
        initial_state.append(torch.stack([forward_init, backward_init], dim=0))

    _, final_state = self.node_modules['robot_future_encoder'](robot_future, tuple(initial_state))
    encoded = unpack_RNN_state(final_state)

    keep_prob = self.hyperparams['rnn_kwargs']['dropout_keep_prob']
    return F.dropout(encoded, p=1. - keep_prob, training=(mode == ModeKeys.TRAIN))
def q_z_xy(self, mode, x, y_e) -> torch.Tensor:
    r"""
    .. math:: q_\phi(z \mid \mathbf{x}_i, \mathbf{y}_i)

    Posterior over the latent variable, conditioned on the encoded input and the encoded future.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param x: Input / Condition tensor.
    :param y_e: Encoded future tensor.
    :return: Latent distribution of the CVAE, as built by ``self.latent.dist_from_h``.
    """
    h = torch.cat([x, y_e], dim=1)
    if self.hyperparams['q_z_xy_MLP_dims'] is not None:
        # Optional hidden layer: ReLU followed by dropout (active only while training).
        hidden = F.relu(self.node_modules[self.node_type + '/q_z_xy'](h))
        h = F.dropout(hidden,
                      p=1. - self.hyperparams['MLP_dropout_keep_prob'],
                      training=(mode == ModeKeys.TRAIN))
    z_params = self.node_modules[self.node_type + '/hxy_to_z'](h)
    return self.latent.dist_from_h(z_params, mode)
def p_z_x(self, mode, x):
    r"""
    .. math:: p_\theta(z \mid \mathbf{x}_i)

    Prior over the latent variable, conditioned only on the encoded input.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param x: Input / Condition tensor.
    :return: Latent distribution of the CVAE, as built by ``self.latent.dist_from_h``.
    """
    h = x
    if self.hyperparams['p_z_x_MLP_dims'] is not None:
        # Optional hidden layer: ReLU followed by dropout (active only while training).
        hidden = F.relu(self.node_modules[self.node_type + '/p_z_x'](x))
        h = F.dropout(hidden,
                      p=1. - self.hyperparams['MLP_dropout_keep_prob'],
                      training=(mode == ModeKeys.TRAIN))
    z_params = self.node_modules[self.node_type + '/hx_to_z'](h)
    return self.latent.dist_from_h(z_params, mode)
def project_to_GMM_params(self, tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
"""
Projects tensor to parameters of a GMM with N components and D dimensions.
:param tensor: Input tensor.
:return: tuple(log_pis, mus, log_sigmas, corrs)
WHERE
- log_pis: Weight (logarithm) of each GMM component. [N]
- mus: Mean of each GMM component. [N, D]
- log_sigmas: Standard Deviation (logarithm) of each GMM component. [N, D]
- corrs: Correlation between the GMM components. [N]
"""
log_pis = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_pis'](tensor)
mus = self.node_modules[self.node_type + '/decoder/proj_to_GMM_mus'](tensor)
log_sigmas = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_sigmas'](tensor)
corrs = torch.tanh(self.node_modules[self.node_type + '/decoder/proj_to_GMM_corrs'](tensor))
return log_pis, mus, log_sigmas, corrs
def p_y_xz(self, mode, x, x_nr_t, y_r, n_s_t0, z_stacked, prediction_horizon,
           num_samples, num_components=1, gmm_mode=False):
    r"""
    .. math:: p_\psi(\mathbf{y}_i \mid \mathbf{x}_i, z)

    Runs the GRU decoder for ``prediction_horizon`` steps. At each step the hidden state is
    projected to per-step GMM parameters and an action ``a_t`` is drawn (the GMM mode, or a
    reparameterized sample). ``a_t`` is fed back as the next input together with ``zx``, plus the
    robot's future state for that step when a robot is used. The per-step parameters are
    collected into one GMM2D over the whole horizon.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param x: Input / Condition tensor.
    :param x_nr_t: Present state of the robot (only used when ``incl_robot_node`` is set).
    :param y_r: Future states of the robot. ``y_r[:, j]`` is part of the decoder input after
        step ``j`` (only used when ``incl_robot_node`` is set).
    :param n_s_t0: Standardized current state of the node.
    :param z_stacked: Stacked latent state. [num_samples_z * num_samples_gmm, bs, latent_state]
    :param prediction_horizon: Number of prediction timesteps.
    :param num_samples: Number of samples from the latent space.
    :param num_components: Number of GMM components.
    :param gmm_mode: If True and mode is Predict, the GMM mode is used instead of a random sample.
    :return: ``y_dist``, the predicted distribution. It is passed through
        ``self.dynamic.integrate_distribution`` when the dynamic's 'distribution' flag is set.
        In Predict mode, ``(y_dist, sampled_future)`` is returned instead.
    """
    ph = prediction_horizon
    pred_dim = self.pred_state_length

    # Flatten latent samples into the batch dimension; x is tiled to match.
    z = torch.reshape(z_stacked, (-1, self.latent.z_dim))
    zx = torch.cat([z, x.repeat(num_samples * num_components, 1)], dim=1)

    cell = self.node_modules[self.node_type + '/decoder/rnn_cell']
    initial_h_model = self.node_modules[self.node_type + '/decoder/initial_h']

    initial_state = initial_h_model(zx)

    # a_sample is only given a real value in the Predict branch at the end.
    log_pis, mus, log_sigmas, corrs, a_sample = [], [], [], [], []

    # Infer initial action state for node from current state
    a_0 = self.node_modules[self.node_type + '/decoder/state_action'](n_s_t0)

    state = initial_state
    if self.hyperparams['incl_robot_node']:
        input_ = torch.cat([zx,
                            a_0.repeat(num_samples * num_components, 1),
                            x_nr_t.repeat(num_samples * num_components, 1)], dim=1)
    else:
        input_ = torch.cat([zx, a_0.repeat(num_samples * num_components, 1)], dim=1)

    for j in range(ph):
        h_state = cell(input_, state)
        log_pi_t, mu_t, log_sigma_t, corr_t = self.project_to_GMM_params(h_state)

        gmm = GMM2D(log_pi_t, mu_t, log_sigma_t, corr_t)  # [k;bs, pred_dim]

        # Action fed back to the decoder at the next step.
        if mode == ModeKeys.PREDICT and gmm_mode:
            a_t = gmm.mode()
        else:
            a_t = gmm.rsample()

        # Mixture weights for the output distribution. With several components they come from
        # the latent distribution's logits: p(z|x) when predicting, q(z|x,y) otherwise. The
        # per-step log_pi_t above only drives the per-step sampling.
        if num_components > 1:
            if mode == ModeKeys.PREDICT:
                log_pis.append(self.latent.p_dist.logits.repeat(num_samples, 1, 1))
            else:
                log_pis.append(self.latent.q_dist.logits.repeat(num_samples, 1, 1))
        else:
            # Single component: a constant weight. Presumably normalized inside GMM2D; confirm.
            log_pis.append(
                torch.ones_like(corr_t.reshape(num_samples, num_components, -1).permute(0, 2, 1).reshape(-1, 1))
            )

        # Regroup per-sample/per-component outputs so that the components sit in the last axis.
        # NOTE(review): these reshapes hard-code 2 output dims per component (GMM2D), i.e. they
        # assume pred_dim == 2.
        mus.append(
            mu_t.reshape(
                num_samples, num_components, -1, 2
            ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components)
        )
        log_sigmas.append(
            log_sigma_t.reshape(
                num_samples, num_components, -1, 2
            ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components))
        corrs.append(
            corr_t.reshape(
                num_samples, num_components, -1
            ).permute(0, 2, 1).reshape(-1, num_components))

        # Next decoder input: zx, the drawn action and (optionally) the robot's state at step j.
        if self.hyperparams['incl_robot_node']:
            dec_inputs = [zx, a_t, y_r[:, j].repeat(num_samples * num_components, 1)]
        else:
            dec_inputs = [zx, a_t]
        input_ = torch.cat(dec_inputs, dim=1)
        state = h_state

    # Stack along a new time axis (dim 1).
    log_pis = torch.stack(log_pis, dim=1)
    mus = torch.stack(mus, dim=1)
    log_sigmas = torch.stack(log_sigmas, dim=1)
    corrs = torch.stack(corrs, dim=1)

    a_dist = GMM2D(torch.reshape(log_pis, [num_samples, -1, ph, num_components]),
                   torch.reshape(mus, [num_samples, -1, ph, num_components * pred_dim]),
                   torch.reshape(log_sigmas, [num_samples, -1, ph, num_components * pred_dim]),
                   torch.reshape(corrs, [num_samples, -1, ph, num_components]))

    # Optionally integrate the action distribution through the dynamics model.
    if self.hyperparams['dynamic'][self.node_type]['distribution']:
        y_dist = self.dynamic.integrate_distribution(a_dist, x)
    else:
        y_dist = a_dist

    if mode == ModeKeys.PREDICT:
        if gmm_mode:
            a_sample = a_dist.mode()
        else:
            a_sample = a_dist.rsample()
        sampled_future = self.dynamic.integrate_samples(a_sample, x)
        return y_dist, sampled_future
    else:
        return y_dist
def encoder(self, mode, x, y_e, num_samples=None):
    """
    Encoder of the CVAE.

    Computes the latent distributions ``self.q_z_xy(mode, x, y_e)`` and ``self.p_z_x(mode, x)``,
    stores them on ``self.latent`` as ``q_dist`` / ``p_dist`` and draws latent samples with
    ``self.latent.sample_q``. The number of samples is ``hyperparams['k']`` in TRAIN mode,
    ``hyperparams['k_eval']`` in EVAL mode and ``num_samples`` in PREDICT mode.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param x: Input / Condition tensor.
    :param y_e: Encoded future tensor.
    :param num_samples: Number of samples from the latent space during Prediction.
                        Required when mode == PREDICT, ignored otherwise.
    :return: tuple(z, kl_obj)
        WHERE
        - z: Samples from the latent space (drawn via sample_q).
        - kl_obj: KL Divergence between q and p in TRAIN mode (also logged under
                  '<node_type>/kl' when a log writer is set), otherwise None.
    :raises ValueError: If mode == PREDICT and num_samples is None, or if mode is not one of
                        TRAIN, EVAL or PREDICT.
    """
    if mode == ModeKeys.TRAIN:
        sample_ct = self.hyperparams['k']
    elif mode == ModeKeys.EVAL:
        sample_ct = self.hyperparams['k_eval']
    elif mode == ModeKeys.PREDICT:
        if num_samples is None:
            raise ValueError("num_samples cannot be None with mode == PREDICT.")
        sample_ct = num_samples
    else:
        # Without this branch an unknown mode would only surface later as an
        # UnboundLocalError on sample_ct, which hides the real cause.
        raise ValueError("Unsupported mode: %s" % (mode,))

    self.latent.q_dist = self.q_z_xy(mode, x, y_e)
    self.latent.p_dist = self.p_z_x(mode, x)

    z = self.latent.sample_q(sample_ct, mode)

    if mode == ModeKeys.TRAIN:
        kl_obj = self.latent.kl_q_p(self.log_writer, '%s' % str(self.node_type), self.curr_iter)
        if self.log_writer is not None:
            self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'kl'), kl_obj, self.curr_iter)
    else:
        kl_obj = None

    return z, kl_obj
def decoder(self, mode, x, x_nr_t, y, y_r, n_s_t0, z, labels, prediction_horizon, num_samples):
    """
    Decoder of the CVAE.

    Decodes the latent samples into an output distribution via ``self.p_y_xz`` and scores
    ``labels`` under it. Per-step log-probabilities are clamped from above at
    ``hyperparams['log_p_yt_xz_max']`` and then summed over dim 2.

    :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
    :param x: Input / Condition tensor.
    :param x_nr_t: Joint state of node and robot (if robot is in scene).
    :param y: Future tensor. Not used by this method; kept for interface stability.
    :param y_r: Robot future tensor, forwarded to p_y_xz (which only reads it when
                hyperparams['incl_robot_node'] is set).
    :param n_s_t0: Standardized current state of the node.
    :param z: Stacked latent state.
    :param labels: Targets whose log-probability is evaluated (train_loss passes the
                   unstandardized labels here).
    :param prediction_horizon: Number of prediction timesteps.
    :param num_samples: Number of samples from the latent space.
    :return: Log probability of ``labels`` under p(y|x, z), summed over dim 2
             (presumably the prediction-timestep axis -- TODO confirm).
    """
    cfg = self.hyperparams
    component_ct = cfg['N'] * cfg['K']

    output_dist = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z,
                              prediction_horizon, num_samples, num_components=component_ct)

    # Upper-bound every per-step log-probability (presumably so very sharp densities
    # cannot dominate the objective).
    step_log_prob = torch.clamp(output_dist.log_prob(labels), max=cfg['log_p_yt_xz_max'])

    writer = self.log_writer
    if cfg['log_histograms'] and writer is not None:
        writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_yt_xz'), step_log_prob, self.curr_iter)

    return torch.sum(step_log_prob, dim=2)
def train_loss(self,
               inputs,
               inputs_st,
               first_history_indices,
               labels,
               labels_st,
               neighbors,
               neighbors_edge_value,
               robot,
               map,
               prediction_horizon) -> torch.Tensor:
    """
    Calculates the training loss for a batch.

    The loss is the negative of
    ``log_likelihood - kl_weight * KL(q || p) + I_p``, where ``log_likelihood`` is the
    batch mean of the decoder's log-probability of the unstandardized labels and ``I_p`` is
    ``mutual_inf_mc`` of the prior latent distribution. Optional TensorBoard logging is
    performed through ``self.log_writer``.

    :param inputs: Input tensor including the state for each agent over time [bs, t, state].
    :param inputs_st: Standardized input tensor.
    :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
    :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state].
    :param labels_st: Standardized label tensor.
    :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                        [[bs, t, neighbor state]]
    :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
    :param robot: Standardized robot state over time. [bs, t, robot_state]
    :param map: Tensor of Map information. [bs, channels, x, y]
    :param prediction_horizon: Number of prediction timesteps.
    :return: Scalar tensor -> negative ELBO (training loss).
    """
    mode = ModeKeys.TRAIN

    x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                                 inputs=inputs,
                                                                 inputs_st=inputs_st,
                                                                 labels=labels,
                                                                 labels_st=labels_st,
                                                                 first_history_indices=first_history_indices,
                                                                 neighbors=neighbors,
                                                                 neighbors_edge_value=neighbors_edge_value,
                                                                 robot=robot,
                                                                 map=map)

    z, kl = self.encoder(mode, x, y_e)

    # The reconstruction term is scored against the unstandardized labels.
    log_p_y_xz = self.decoder(mode, x, x_nr_t, y, y_r, n_s_t0, z,
                              labels,
                              prediction_horizon,
                              self.hyperparams['k'])

    per_node_log_lik = torch.mean(log_p_y_xz, dim=0)  # [nbs]
    log_likelihood = torch.mean(per_node_log_lik)

    mutual_inf_q = mutual_inf_mc(self.latent.q_dist)
    mutual_inf_p = mutual_inf_mc(self.latent.p_dist)

    elbo = log_likelihood - self.kl_weight * kl + 1. * mutual_inf_p
    loss = -elbo

    node_tag = str(self.node_type)
    writer = self.log_writer

    if self.hyperparams['log_histograms'] and writer is not None:
        writer.add_histogram('%s/%s' % (node_tag, 'log_p_y_xz'),
                             per_node_log_lik,
                             self.curr_iter)

    if writer is not None:
        scalars = (('mutual_information_q', mutual_inf_q),
                   ('mutual_information_p', mutual_inf_p),
                   ('log_likelihood', log_likelihood),
                   ('loss', loss))
        for tag_name, value in scalars:
            writer.add_scalar('%s/%s' % (node_tag, tag_name), value, self.curr_iter)
        if self.hyperparams['log_histograms']:
            self.latent.summarize_for_tensorboard(writer, node_tag, self.curr_iter)

    return loss
def eval_loss(self,
              inputs,
              inputs_st,
              first_history_indices,
              labels,
              labels_st,
              neighbors,
              neighbors_edge_value,
              robot,
              map,
              prediction_horizon) -> torch.Tensor:
    """
    Calculates the evaluation loss for a batch.

    Decodes the latent states returned by ``self.latent.sample_p(1, mode, full_dist=True)``
    and evaluates the negative log-likelihood of the unstandardized ``labels`` under the
    resulting output distribution. The decoder is run in PREDICT mode, so (for more than one
    component) the mixture weights come from the prior logits ``self.latent.p_dist.logits``.

    :param inputs: Input tensor including the state for each agent over time [bs, t, state].
    :param inputs_st: Standardized input tensor.
    :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
    :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state].
    :param labels_st: Standardized label tensor.
    :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                        [[bs, t, neighbor state]]
    :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
    :param robot: Standardized robot state over time. [bs, t, robot_state]
    :param map: Tensor of Map information. [bs, channels, x, y]
    :param prediction_horizon: Number of prediction timesteps.
    :return: Scalar tensor -> nll (negative mean log-likelihood over the batch).
    """
    mode = ModeKeys.EVAL
    x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                                 inputs=inputs,
                                                                 inputs_st=inputs_st,
                                                                 labels=labels,
                                                                 labels_st=labels_st,
                                                                 first_history_indices=first_history_indices,
                                                                 neighbors=neighbors,
                                                                 neighbors_edge_value=neighbors_edge_value,
                                                                 robot=robot,
                                                                 map=map)

    # The encoder is run for its side effect of (re)building self.latent.q_dist and
    # self.latent.p_dist for this batch; its latent samples are not used here.
    self.encoder(mode, x, y_e)

    # sample_p returns (z, num_samples, num_components) -- predict() unpacks it the same way.
    # Unpack it so that a tensor (not the whole tuple) reaches p_y_xz, and use the component
    # count that actually matches the returned z.
    z, num_samples, num_components = self.latent.sample_p(1, mode, full_dist=True)
    y_dist, _ = self.p_y_xz(ModeKeys.PREDICT, x, x_nr_t, y_r, n_s_t0, z,
                            prediction_horizon, num_samples=num_samples, num_components=num_components)

    # We use unstandardized labels to compute the loss
    log_p_yt_xz = torch.clamp(y_dist.log_prob(labels), max=self.hyperparams['log_p_yt_xz_max'])
    log_p_y_xz = torch.sum(log_p_yt_xz, dim=2)
    log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0)  # [nbs]
    log_likelihood = torch.mean(log_p_y_xz_mean)
    nll = -log_likelihood
    return nll
def predict(self,
            inputs,
            inputs_st,
            first_history_indices,
            neighbors,
            neighbors_edge_value,
            robot,
            map,
            prediction_horizon,
            num_samples,
            z_mode=False,
            gmm_mode=False,
            full_dist=True,
            all_z_sep=False):
    """
    Predicts the future of a batch of nodes.

    Encodes the history, builds the prior latent distribution ``p_z_x``, samples latent
    states from it and decodes them in PREDICT mode.

    :param inputs: Input tensor including the state for each agent over time [bs, t, state].
    :param inputs_st: Standardized input tensor.
    :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
    :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                        [[bs, t, neighbor state]]
    :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
    :param robot: Standardized robot state over time. [bs, t, robot_state]
    :param map: Tensor of Map information. [bs, channels, x, y]
    :param prediction_horizon: Number of prediction timesteps.
    :param num_samples: Number of samples from the latent space.
    :param z_mode: If True: Select the most likely latent state.
    :param gmm_mode: If True: The mode of the GMM is used instead of a random sample.
    :param full_dist: Samples all latent states and merges them into a GMM as output.
    :param all_z_sep: Samples each latent mode individually without merging them into a GMM.
    :return: The sampled future produced by p_y_xz in PREDICT mode, i.e. the output of
             ``self.dynamic.integrate_samples`` applied to the action samples.
    """
    mode = ModeKeys.PREDICT

    # No labels are available at prediction time, so the future-related outputs are discarded.
    x, x_nr_t, _, y_r, _, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                               inputs=inputs,
                                                               inputs_st=inputs_st,
                                                               labels=None,
                                                               labels_st=None,
                                                               first_history_indices=first_history_indices,
                                                               neighbors=neighbors,
                                                               neighbors_edge_value=neighbors_edge_value,
                                                               robot=robot,
                                                               map=map)

    latent = self.latent
    latent.p_dist = self.p_z_x(mode, x)

    # sample_p may adjust how many samples / components are decoded, so use its values.
    z, sample_count, component_count = latent.sample_p(num_samples,
                                                       mode,
                                                       most_likely_z=z_mode,
                                                       full_dist=full_dist,
                                                       all_z_sep=all_z_sep)

    _, sampled_future = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z,
                                    prediction_horizon,
                                    sample_count,
                                    component_count,
                                    gmm_mode)
    return sampled_future