import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from model.components import *
from model.model_utils import *
import model.dynamics as dynamic_module
from environment.scene_graph import DirectedEdge


class MultimodalGenerativeCVAE(object):
    def __init__(self,
                 env,
                 node_type,
                 model_registrar,
                 hyperparams,
                 device,
                 edge_types,
                 log_writer=None):
        self.hyperparams = hyperparams
        self.env = env
        self.node_type = node_type
        self.model_registrar = model_registrar
        self.log_writer = log_writer
        self.device = device
        # Keep only the edge types that originate from this node type.
        self.edge_types = [edge_type for edge_type in edge_types if edge_type[0] is node_type]
        self.curr_iter = 0

        self.node_modules = dict()

        self.min_hl = self.hyperparams['minimum_history_length']
        self.max_hl = self.hyperparams['maximum_history_length']
        self.ph = self.hyperparams['prediction_horizon']
        self.state = self.hyperparams['state']
        self.pred_state = self.hyperparams['pred_state'][node_type]
        self.state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[node_type].values()]))
        if self.hyperparams['incl_robot_node']:
            self.robot_state_length = int(
                np.sum([len(entity_dims) for entity_dims in self.state[env.robot_type].values()])
            )
        self.pred_state_length = int(np.sum([len(entity_dims) for entity_dims in self.pred_state.values()]))

        edge_types_str = [DirectedEdge.get_str_from_types(*edge_type) for edge_type in self.edge_types]
        self.create_graphical_model(edge_types_str)

        dynamic_class = getattr(dynamic_module, hyperparams['dynamic'][self.node_type]['name'])
        dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
        self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device,
                                     self.model_registrar, self.x_size, self.node_type)
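
        # Illustrative note (added; values are assumptions, not from this file):
        # with a typical pedestrian config such as
        #   state[node_type] = {'position': ['x', 'y'], 'velocity': ['x', 'y'], 'acceleration': ['x', 'y']}
        #   pred_state       = {'velocity': ['x', 'y']}
        # state_length evaluates to 6 and pred_state_length to 2. These lengths
        # size the encoder LSTMs and decoder projections built below.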

    def set_curr_iter(self, curr_iter):
        self.curr_iter = curr_iter

    def add_submodule(self, name, model_if_absent):
        self.node_modules[name] = self.model_registrar.get_model(name, model_if_absent)

    def clear_submodules(self):
        self.node_modules.clear()

    def create_node_models(self):
        ############################
        #   Node History Encoder   #
        ############################
        self.add_submodule(self.node_type + '/node_history_encoder',
                           model_if_absent=nn.LSTM(input_size=self.state_length,
                                                   hidden_size=self.hyperparams['enc_rnn_dim_history'],
                                                   batch_first=True))

        ###########################
        #   Node Future Encoder   #
        ###########################
        # We'll create this here, but then later check if in training mode.
        # Based on that, we'll factor this into the computation graph (or not).
        self.add_submodule(self.node_type + '/node_future_encoder',
                           model_if_absent=nn.LSTM(input_size=self.pred_state_length,
                                                   hidden_size=self.hyperparams['enc_rnn_dim_future'],
                                                   bidirectional=True,
                                                   batch_first=True))
        # These are related to how you initialize states for the node future encoder.
        self.add_submodule(self.node_type + '/node_future_encoder/initial_h',
                           model_if_absent=nn.Linear(self.state_length,
                                                     self.hyperparams['enc_rnn_dim_future']))
        self.add_submodule(self.node_type + '/node_future_encoder/initial_c',
                           model_if_absent=nn.Linear(self.state_length,
                                                     self.hyperparams['enc_rnn_dim_future']))

        ############################
        #   Robot Future Encoder   #
        ############################
        # We'll create this here, but then later check if we're next to the robot.
        # Based on that, we'll factor this into the computation graph (or not).
        if self.hyperparams['incl_robot_node']:
            self.add_submodule('robot_future_encoder',
                               model_if_absent=nn.LSTM(input_size=self.robot_state_length,
                                                       hidden_size=self.hyperparams['enc_rnn_dim_future'],
                                                       bidirectional=True,
                                                       batch_first=True))
            # These are related to how you initialize states for the robot future encoder.
            self.add_submodule('robot_future_encoder/initial_h',
                               model_if_absent=nn.Linear(self.robot_state_length,
                                                         self.hyperparams['enc_rnn_dim_future']))
            self.add_submodule('robot_future_encoder/initial_c',
                               model_if_absent=nn.Linear(self.robot_state_length,
                                                         self.hyperparams['enc_rnn_dim_future']))

        if self.hyperparams['edge_encoding']:
            ##############################
            #   Edge Influence Encoder   #
            ##############################
            # NOTE: The edge influence encoding happens during calls
            # to forward or incremental_forward, so we don't create
            # a model for it here for the max and sum variants.
            if self.hyperparams['edge_influence_combine_method'] == 'bi-rnn':
                self.add_submodule(self.node_type + '/edge_influence_encoder',
                                   model_if_absent=nn.LSTM(input_size=self.hyperparams['enc_rnn_dim_edge'],
                                                           hidden_size=self.hyperparams['enc_rnn_dim_edge_influence'],
                                                           bidirectional=True,
                                                           batch_first=True))

                # Four times because we're trying to mimic a bi-directional
                # LSTM's output (which, here, is c and h from both ends).
                self.eie_output_dims = 4 * self.hyperparams['enc_rnn_dim_edge_influence']

            elif self.hyperparams['edge_influence_combine_method'] == 'attention':
                # Chose additive attention because of https://arxiv.org/pdf/1703.03906.pdf
                # We calculate an attention context vector using the encoded edges as the "encoder"
                # (that we attend _over_)
                # and the node history encoder representation as the "decoder state" (that we attend _on_).
                self.add_submodule(self.node_type + '/edge_influence_encoder',
                                   model_if_absent=AdditiveAttention(
                                       encoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_edge_influence'],
                                       decoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_history']))

                self.eie_output_dims = self.hyperparams['enc_rnn_dim_edge_influence']

        ###################
        #   Map Encoder   #
        ###################
        if self.hyperparams['use_map_encoding']:
            if self.node_type in self.hyperparams['map_encoder']:
                me_params = self.hyperparams['map_encoder'][self.node_type]
                self.add_submodule(self.node_type + '/map_encoder',
                                   model_if_absent=CNNMapEncoder(me_params['map_channels'],
                                                                 me_params['hidden_channels'],
                                                                 me_params['output_size'],
                                                                 me_params['masks'],
                                                                 me_params['strides'],
                                                                 me_params['patch_size']))

        ################################
        #   Discrete Latent Variable   #
        ################################
        self.latent = DiscreteLatent(self.hyperparams, self.device)

        ######################################################################
        #   Various Fully-Connected Layers from Encoder to Latent Variable   #
        ######################################################################
        # Node History Encoder
        x_size = self.hyperparams['enc_rnn_dim_history']
        if self.hyperparams['edge_encoding']:
            # Edge Encoder
            x_size += self.eie_output_dims
        if self.hyperparams['incl_robot_node']:
            # Future Conditional Encoder
            x_size += 4 * self.hyperparams['enc_rnn_dim_future']
        if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
            # Map Encoder
            x_size += self.hyperparams['map_encoder'][self.node_type]['output_size']

        z_size = self.hyperparams['N'] * self.hyperparams['K']
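
        # Worked example (added; the numbers are assumed, not from this file):
        # with enc_rnn_dim_history = 32 and enc_rnn_dim_edge_influence = 32 under
        # the bi-rnn combine method (eie_output_dims = 128), no robot and no map,
        # x_size = 32 + 128 = 160; with N = 1 discrete latent variables of K = 25
        # categories each, z_size = 25.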

        if self.hyperparams['p_z_x_MLP_dims'] is not None:
            self.add_submodule(self.node_type + '/p_z_x',
                               model_if_absent=nn.Linear(x_size, self.hyperparams['p_z_x_MLP_dims']))
            hx_size = self.hyperparams['p_z_x_MLP_dims']
        else:
            hx_size = x_size

        self.add_submodule(self.node_type + '/hx_to_z',
                           model_if_absent=nn.Linear(hx_size, self.latent.z_dim))

        if self.hyperparams['q_z_xy_MLP_dims'] is not None:
            self.add_submodule(self.node_type + '/q_z_xy',
                               # Node Future Encoder
                               model_if_absent=nn.Linear(x_size + 4 * self.hyperparams['enc_rnn_dim_future'],
                                                         self.hyperparams['q_z_xy_MLP_dims']))
            hxy_size = self.hyperparams['q_z_xy_MLP_dims']
        else:
            # Node Future Encoder
            hxy_size = x_size + 4 * self.hyperparams['enc_rnn_dim_future']

        self.add_submodule(self.node_type + '/hxy_to_z',
                           model_if_absent=nn.Linear(hxy_size, self.latent.z_dim))

        ####################
        #   Decoder LSTM   #
        ####################
        if self.hyperparams['incl_robot_node']:
            decoder_input_dims = self.pred_state_length + self.robot_state_length + z_size + x_size
        else:
            decoder_input_dims = self.pred_state_length + z_size + x_size

        self.add_submodule(self.node_type + '/decoder/state_action',
                           model_if_absent=nn.Sequential(
                               nn.Linear(self.state_length, self.pred_state_length)))

        self.add_submodule(self.node_type + '/decoder/rnn_cell',
                           model_if_absent=nn.GRUCell(decoder_input_dims, self.hyperparams['dec_rnn_dim']))
        self.add_submodule(self.node_type + '/decoder/initial_h',
                           model_if_absent=nn.Linear(z_size + x_size, self.hyperparams['dec_rnn_dim']))

        ###################
        #   Decoder GMM   #
        ###################
        self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_pis',
                           model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                     self.hyperparams['GMM_components']))
        self.add_submodule(self.node_type + '/decoder/proj_to_GMM_mus',
                           model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                     self.hyperparams['GMM_components'] * self.pred_state_length))
        self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_sigmas',
                           model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                     self.hyperparams['GMM_components'] * self.pred_state_length))
        self.add_submodule(self.node_type + '/decoder/proj_to_GMM_corrs',
                           model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
                                                     self.hyperparams['GMM_components']))

        self.x_size = x_size
        self.z_size = z_size

    def create_edge_models(self, edge_types):
        for edge_type in edge_types:
            neighbor_state_length = int(
                np.sum([len(entity_dims) for entity_dims in self.state[edge_type.split('->')[1]].values()]))
            if self.hyperparams['edge_state_combine_method'] == 'pointnet':
                self.add_submodule(edge_type + '/pointnet_encoder',
                                   model_if_absent=nn.Sequential(
                                       nn.Linear(self.state_length, 2 * self.state_length),
                                       nn.ReLU(),
                                       nn.Linear(2 * self.state_length, 2 * self.state_length),
                                       nn.ReLU()))

                edge_encoder_input_size = 2 * self.state_length + self.state_length

            elif self.hyperparams['edge_state_combine_method'] == 'attention':
                self.add_submodule(self.node_type + '/edge_attention_combine',
                                   model_if_absent=TemporallyBatchedAdditiveAttention(
                                       encoder_hidden_state_dim=self.state_length,
                                       decoder_hidden_state_dim=self.state_length))
                edge_encoder_input_size = self.state_length + neighbor_state_length

            else:
                edge_encoder_input_size = self.state_length + neighbor_state_length

            self.add_submodule(edge_type + '/edge_encoder',
                               model_if_absent=nn.LSTM(input_size=edge_encoder_input_size,
                                                       hidden_size=self.hyperparams['enc_rnn_dim_edge'],
                                                       batch_first=True))

    def create_graphical_model(self, edge_types):
        """
        Creates or queries all trainable components.

        :param edge_types: List containing strings for all possible edge types for the node type.
        :return: None
        """
        self.clear_submodules()

        ############################
        #   Everything but Edges   #
        ############################
        self.create_node_models()

        #####################
        #   Edge Encoders   #
        #####################
        if self.hyperparams['edge_encoding']:
            self.create_edge_models(edge_types)

        for name, module in self.node_modules.items():
            module.to(self.device)

    def create_new_scheduler(self, name, annealer, annealer_kws, creation_condition=True):
        value_scheduler = None
        rsetattr(self, name + '_scheduler', value_scheduler)
        if creation_condition:
            annealer_kws['device'] = self.device
            value_annealer = annealer(annealer_kws)
            rsetattr(self, name + '_annealer', value_annealer)

            # This is the value that we'll update on each call of
            # step_annealers().
            rsetattr(self, name, value_annealer(0).clone().detach())
            dummy_optimizer = optim.Optimizer([rgetattr(self, name)], {'lr': value_annealer(0).clone().detach()})
            rsetattr(self, name + '_optimizer', dummy_optimizer)

            value_scheduler = CustomLR(dummy_optimizer,
                                       value_annealer)
            rsetattr(self, name + '_scheduler', value_scheduler)

        self.schedulers.append(value_scheduler)
        self.annealed_vars.append(name)
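
        # How this works (added note): the annealed value is stored as the
        # "learning rate" of a dummy optimizer, and CustomLR (an LR scheduler
        # driven by the annealer function) updates it on every scheduler.step().
        # step_annealers() then copies param_groups[0]['lr'] back into the
        # attribute named by `name`, so e.g. self.kl_weight follows its
        # annealing curve without any bespoke bookkeeping.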

    def set_annealing_params(self):
        self.schedulers = list()
        self.annealed_vars = list()

        self.create_new_scheduler(name='kl_weight',
                                  annealer=sigmoid_anneal,
                                  annealer_kws={
                                      'start': self.hyperparams['kl_weight_start'],
                                      'finish': self.hyperparams['kl_weight'],
                                      'center_step': self.hyperparams['kl_crossover'],
                                      'steps_lo_to_hi': self.hyperparams['kl_crossover'] / self.hyperparams[
                                          'kl_sigmoid_divisor']
                                  })
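
        # Shape of the schedule (added note, assuming the usual sigmoid_anneal
        # form): value(step) = start + (finish - start) *
        # sigmoid((step - center_step) / steps_lo_to_hi), i.e. the KL weight
        # ramps smoothly from kl_weight_start to kl_weight, crossing the
        # midpoint at kl_crossover.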

        self.create_new_scheduler(name='latent.temp',
                                  annealer=exp_anneal,
                                  annealer_kws={
                                      'start': self.hyperparams['tau_init'],
                                      'finish': self.hyperparams['tau_final'],
                                      'rate': self.hyperparams['tau_decay_rate']
                                  })

        self.create_new_scheduler(name='latent.z_logit_clip',
                                  annealer=sigmoid_anneal,
                                  annealer_kws={
                                      'start': self.hyperparams['z_logit_clip_start'],
                                      'finish': self.hyperparams['z_logit_clip_final'],
                                      'center_step': self.hyperparams['z_logit_clip_crossover'],
                                      'steps_lo_to_hi': self.hyperparams['z_logit_clip_crossover'] / self.hyperparams[
                                          'z_logit_clip_divisor']
                                  },
                                  creation_condition=self.hyperparams['use_z_logit_clipping'])

    def step_annealers(self):
        # This should manage all of the step-wise changed
        # parameters automatically.
        for annealed_var in self.annealed_vars:
            if rgetattr(self, annealed_var + '_scheduler') is not None:
                # First we step the scheduler.
                with warnings.catch_warnings():  # We use a dummy optimizer: Warning because no .step() was called on it
                    warnings.simplefilter("ignore")
                    rgetattr(self, annealed_var + '_scheduler').step()

                # Then we set the annealed vars' value.
                rsetattr(self, annealed_var, rgetattr(self, annealed_var + '_optimizer').param_groups[0]['lr'])

        self.summarize_annealers()

    def summarize_annealers(self):
        if self.log_writer is not None:
            for annealed_var in self.annealed_vars:
                if rgetattr(self, annealed_var) is not None:
                    self.log_writer.add_scalar('%s/%s' % (str(self.node_type), annealed_var.replace('.', '/')),
                                               rgetattr(self, annealed_var), self.curr_iter)

    def obtain_encoded_tensors(self,
                               mode,
                               inputs,
                               inputs_st,
                               labels,
                               labels_st,
                               first_history_indices,
                               neighbors,
                               neighbors_edge_value,
                               robot,
                               map) -> (torch.Tensor,
                                        torch.Tensor,
                                        torch.Tensor,
                                        torch.Tensor,
                                        torch.Tensor,
                                        torch.Tensor):
        """
        Encodes input and output tensors for node and robot.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param inputs: Input tensor including the state for each agent over time [bs, t, state].
        :param inputs_st: Standardized input tensor.
        :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state].
        :param labels_st: Standardized label tensor.
        :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
        :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                            [[bs, t, neighbor state]]
        :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
        :param robot: Standardized robot state over time. [bs, t, robot_state]
        :param map: Tensor of Map information. [bs, channels, x, y]
        :return: tuple(x, x_r_t, y_e, y_r, y, n_s_t0)
            WHERE
            - x: Encoded input / condition tensor to the CVAE x_e.
            - x_r_t: Robot state (if robot is in scene).
            - y_e: Encoded label / future of the node.
            - y_r: Encoded future of the robot.
            - y: Label / future of the node.
            - n_s_t0: Standardized current state of the node.
        """
        x, x_r_t, y_e, y_r, y = None, None, None, None, None
        initial_dynamics = dict()

        batch_size = inputs.shape[0]

        #########################################
        # Provide basic information to encoders #
        #########################################
        node_history = inputs
        node_present_state = inputs[:, -1]
        node_pos = inputs[:, -1, 0:2]
        node_vel = inputs[:, -1, 2:4]

        node_history_st = inputs_st
        node_present_state_st = inputs_st[:, -1]
        node_pos_st = inputs_st[:, -1, 0:2]
        node_vel_st = inputs_st[:, -1, 2:4]

        n_s_t0 = node_present_state_st

        initial_dynamics['pos'] = node_pos
        initial_dynamics['vel'] = node_vel

        self.dynamic.set_initial_condition(initial_dynamics)

        if self.hyperparams['incl_robot_node']:
            x_r_t, y_r = robot[..., 0, :], robot[..., 1:, :]
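            # The robot tensor covers the present plus the prediction horizon:
            # index 0 along the time axis is the robot's current state (x_r_t),
            # indices 1: are its future trajectory (y_r).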

        ##################
        # Encode History #
        ##################
        node_history_encoded = self.encode_node_history(mode,
                                                        node_history_st,
                                                        first_history_indices)

        ##################
        # Encode Present #
        ##################
        node_present = node_present_state_st  # [bs, state_dim]

        #################
        # Encode Future #
        #################
        if mode != ModeKeys.PREDICT:
            y = labels_st

        ##############################
        # Encode Node Edges per Type #
        ##############################
        if self.hyperparams['edge_encoding']:
            node_edges_encoded = list()
            for edge_type in self.edge_types:
                # Encode edges for given edge type
                encoded_edges_type = self.encode_edge(mode,
                                                      node_history,
                                                      node_history_st,
                                                      edge_type,
                                                      neighbors[edge_type],
                                                      neighbors_edge_value[edge_type],
                                                      first_history_indices)
                node_edges_encoded.append(encoded_edges_type)  # List of [bs/nbs, enc_rnn_dim]
            #####################
            # Encode Node Edges #
            #####################
            total_edge_influence = self.encode_total_edge_influence(mode,
                                                                    node_edges_encoded,
                                                                    node_history_encoded,
                                                                    batch_size)

        ################
        # Map Encoding #
        ################
        if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
            if self.log_writer and (self.curr_iter + 1) % 500 == 0:
                map_clone = map.clone()
                map_patch = self.hyperparams['map_encoder'][self.node_type]['patch_size']
                map_clone[:, :, map_patch[1] - 5:map_patch[1] + 5, map_patch[0] - 5:map_patch[0] + 5] = 1.
                self.log_writer.add_images(f"{self.node_type}/cropped_maps", map_clone,
                                           self.curr_iter, dataformats='NCWH')

            encoded_map = self.node_modules[self.node_type + '/map_encoder'](map * 2. - 1., (mode == ModeKeys.TRAIN))
            do = self.hyperparams['map_encoder'][self.node_type]['dropout']
            encoded_map = F.dropout(encoded_map, do, training=(mode == ModeKeys.TRAIN))

        ######################################
        # Concatenate Encoder Outputs into x #
        ######################################
        x_concat_list = list()

        # Every node has an edge-influence encoder (which could just be zero).
        if self.hyperparams['edge_encoding']:
            x_concat_list.append(total_edge_influence)  # [bs/nbs, 4*enc_rnn_dim]

        # Every node has a history encoder.
        x_concat_list.append(node_history_encoded)  # [bs/nbs, enc_rnn_dim_history]

        if self.hyperparams['incl_robot_node']:
            robot_future_encoder = self.encode_robot_future(mode, x_r_t, y_r)
            x_concat_list.append(robot_future_encoder)

        if self.hyperparams['use_map_encoding'] and self.node_type in self.hyperparams['map_encoder']:
            if self.log_writer:
                self.log_writer.add_scalar(f"{self.node_type}/encoded_map_max",
                                           torch.max(torch.abs(encoded_map)), self.curr_iter)
            x_concat_list.append(encoded_map)

        x = torch.cat(x_concat_list, dim=1)
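
        # Layout of x (added note): [edge influence | node history | robot future
        # | map], in that order, with blocks included only when the corresponding
        # hyperparameter enables them; its width matches self.x_size computed in
        # create_node_models().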

        if mode == ModeKeys.TRAIN or mode == ModeKeys.EVAL:
            y_e = self.encode_node_future(mode, node_present, y)

        return x, x_r_t, y_e, y_r, y, n_s_t0

    def encode_node_history(self, mode, node_hist, first_history_indices):
        """
        Encodes the node history.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param node_hist: Historic and current state of the node. [bs, mhl, state]
        :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
        :return: Encoded node history tensor. [bs, enc_rnn_dim]
        """
        outputs, _ = run_lstm_on_variable_length_seqs(self.node_modules[self.node_type + '/node_history_encoder'],
                                                      original_seqs=node_hist,
                                                      lower_indices=first_history_indices)

        outputs = F.dropout(outputs,
                            p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                            training=(mode == ModeKeys.TRAIN))  # [bs, max_time, enc_rnn_dim]

        last_index_per_sequence = -(first_history_indices + 1)
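
        # Why the negative index (added note): the packed-sequence outputs are
        # padded at the end, so a node whose history starts at index f has
        # max_time - f valid steps and its last valid output sits at
        # index max_time - f - 1, i.e. -(f + 1) from the end.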

        return outputs[torch.arange(first_history_indices.shape[0]), last_index_per_sequence]

    def encode_edge(self,
                    mode,
                    node_history,
                    node_history_st,
                    edge_type,
                    neighbors,
                    neighbors_edge_value,
                    first_history_indices):

        max_hl = self.hyperparams['maximum_history_length']

        edge_states_list = list()  # list of [#of neighbors, max_ht, state_dim]
        for neighbor_states in neighbors:  # Get neighbors for timestep in batch
            if len(neighbor_states) == 0:  # There are no neighbors for edge type # TODO necessary?
                neighbor_state_length = int(
                    np.sum([len(entity_dims) for entity_dims in self.state[edge_type[1]].values()])
                )
                edge_states_list.append(torch.zeros((1, max_hl + 1, neighbor_state_length), device=self.device))
            else:
                edge_states_list.append(torch.stack(neighbor_states, dim=0).to(self.device))

        if self.hyperparams['edge_state_combine_method'] == 'sum':
            # Used in Structural-RNN to combine edges as well.
            op_applied_edge_states_list = list()
            for neighbors_state in edge_states_list:
                op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0))
            combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
            if self.hyperparams['dynamic_edges'] == 'yes':
                # Should now be (bs, time, 1)
                op_applied_edge_mask_list = list()
                for edge_value in neighbors_edge_value:
                    op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_value.to(self.device),
                                                                           dim=0, keepdim=True), max=1.))
                combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)

        elif self.hyperparams['edge_state_combine_method'] == 'max':
            # Used in NLP, e.g. max over word embeddings in a sentence.
            op_applied_edge_states_list = list()
            for neighbors_state in edge_states_list:
                # .values: torch.max(..., dim=0) returns a (values, indices) tuple.
                op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0).values)
            combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
            if self.hyperparams['dynamic_edges'] == 'yes':
                # Should now be (bs, time, 1)
                op_applied_edge_mask_list = list()
                for edge_value in neighbors_edge_value:
                    op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_value.to(self.device),
                                                                           dim=0, keepdim=True).values, max=1.))
                combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)

        elif self.hyperparams['edge_state_combine_method'] == 'mean':
            # Used in NLP, e.g. mean over word embeddings in a sentence.
            op_applied_edge_states_list = list()
            for neighbors_state in edge_states_list:
                op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0))
            combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
            if self.hyperparams['dynamic_edges'] == 'yes':
                # Should now be (bs, time, 1)
                op_applied_edge_mask_list = list()
                for edge_value in neighbors_edge_value:
                    op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_value.to(self.device),
                                                                            dim=0, keepdim=True), max=1.))
                combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)

        joint_history = torch.cat([combined_neighbors, node_history_st], dim=-1)

        outputs, _ = run_lstm_on_variable_length_seqs(
            self.node_modules[DirectedEdge.get_str_from_types(*edge_type) + '/edge_encoder'],
            original_seqs=joint_history,
            lower_indices=first_history_indices
        )

        outputs = F.dropout(outputs,
                            p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                            training=(mode == ModeKeys.TRAIN))  # [bs, max_time, enc_rnn_dim]

        last_index_per_sequence = -(first_history_indices + 1)
        ret = outputs[torch.arange(last_index_per_sequence.shape[0]), last_index_per_sequence]
        if self.hyperparams['dynamic_edges'] == 'yes':
            return ret * combined_edge_masks
        else:
            return ret

    def encode_total_edge_influence(self, mode, encoded_edges, node_history_encoder, batch_size):
        if self.hyperparams['edge_influence_combine_method'] == 'sum':
            stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
            combined_edges = torch.sum(stacked_encoded_edges, dim=0)

        elif self.hyperparams['edge_influence_combine_method'] == 'mean':
            stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
            combined_edges = torch.mean(stacked_encoded_edges, dim=0)

        elif self.hyperparams['edge_influence_combine_method'] == 'max':
            stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
            # .values: torch.max(..., dim=0) returns a (values, indices) tuple.
            combined_edges = torch.max(stacked_encoded_edges, dim=0).values

        elif self.hyperparams['edge_influence_combine_method'] == 'bi-rnn':
            if len(encoded_edges) == 0:
                combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)

            else:
                # axis=1 because then we get size [batch_size, max_time, depth]
                encoded_edges = torch.stack(encoded_edges, dim=1)

                _, state = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges)
                combined_edges = unpack_RNN_state(state)
                combined_edges = F.dropout(combined_edges,
                                           p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                                           training=(mode == ModeKeys.TRAIN))

        elif self.hyperparams['edge_influence_combine_method'] == 'attention':
            # Used in Social Attention (https://arxiv.org/abs/1710.04689)
            if len(encoded_edges) == 0:
                combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)

            else:
                # axis=1 because then we get size [batch_size, max_time, depth]
                encoded_edges = torch.stack(encoded_edges, dim=1)
                combined_edges, _ = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges,
                                                                                                  node_history_encoder)
                combined_edges = F.dropout(combined_edges,
                                           p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                                           training=(mode == ModeKeys.TRAIN))

        return combined_edges

    def encode_node_future(self, mode, node_present, node_future) -> torch.Tensor:
        """
        Encodes the node future (during training) using a bi-directional LSTM.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param node_present: Current state of the node. [bs, state]
        :param node_future: Future states of the node. [bs, ph, state]
        :return: Encoded future.
        """
        initial_h_model = self.node_modules[self.node_type + '/node_future_encoder/initial_h']
        initial_c_model = self.node_modules[self.node_type + '/node_future_encoder/initial_c']

        # Here we're initializing the forward hidden states,
        # but zeroing the backward ones.
        initial_h = initial_h_model(node_present)
        initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0)

        initial_c = initial_c_model(node_present)
        initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0)

        initial_state = (initial_h, initial_c)

        _, state = self.node_modules[self.node_type + '/node_future_encoder'](node_future, initial_state)
        state = unpack_RNN_state(state)
        state = F.dropout(state,
                          p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                          training=(mode == ModeKeys.TRAIN))

        return state

    def encode_robot_future(self, mode, robot_present, robot_future) -> torch.Tensor:
        """
        Encodes the robot future (during training) using a bi-directional LSTM.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param robot_present: Current state of the robot. [bs, state]
        :param robot_future: Future states of the robot. [bs, ph, state]
        :return: Encoded future.
        """
        initial_h_model = self.node_modules['robot_future_encoder/initial_h']
        initial_c_model = self.node_modules['robot_future_encoder/initial_c']

        # Here we're initializing the forward hidden states,
        # but zeroing the backward ones.
        initial_h = initial_h_model(robot_present)
        initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0)

        initial_c = initial_c_model(robot_present)
        initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0)

        initial_state = (initial_h, initial_c)

        _, state = self.node_modules['robot_future_encoder'](robot_future, initial_state)
        state = unpack_RNN_state(state)
        state = F.dropout(state,
                          p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
                          training=(mode == ModeKeys.TRAIN))

        return state

    def q_z_xy(self, mode, x, y_e) -> torch.Tensor:
        r"""
        .. math:: q_\phi(z \mid \mathbf{x}_i, \mathbf{y}_i)

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param x: Input / Condition tensor.
        :param y_e: Encoded future tensor.
        :return: Latent distribution of the CVAE.
        """
        xy = torch.cat([x, y_e], dim=1)

        if self.hyperparams['q_z_xy_MLP_dims'] is not None:
            dense = self.node_modules[self.node_type + '/q_z_xy']
            h = F.dropout(F.relu(dense(xy)),
                          p=1. - self.hyperparams['MLP_dropout_keep_prob'],
                          training=(mode == ModeKeys.TRAIN))

        else:
            h = xy

        to_latent = self.node_modules[self.node_type + '/hxy_to_z']
        return self.latent.dist_from_h(to_latent(h), mode)

    def p_z_x(self, mode, x):
        r"""
        .. math:: p_\theta(z \mid \mathbf{x}_i)

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param x: Input / Condition tensor.
        :return: Latent distribution of the CVAE.
        """
        if self.hyperparams['p_z_x_MLP_dims'] is not None:
            dense = self.node_modules[self.node_type + '/p_z_x']
            h = F.dropout(F.relu(dense(x)),
                          p=1. - self.hyperparams['MLP_dropout_keep_prob'],
                          training=(mode == ModeKeys.TRAIN))

        else:
            h = x

        to_latent = self.node_modules[self.node_type + '/hx_to_z']
        return self.latent.dist_from_h(to_latent(h), mode)

    def project_to_GMM_params(self, tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """
        Projects tensor to parameters of a GMM with N components and D dimensions.

        :param tensor: Input tensor.
        :return: tuple(log_pis, mus, log_sigmas, corrs)
            WHERE
            - log_pis: Weight (logarithm) of each GMM component. [N]
            - mus: Mean of each GMM component. [N, D]
            - log_sigmas: Standard Deviation (logarithm) of each GMM component. [N, D]
            - corrs: Correlation between the GMM components. [N]
        """
        log_pis = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_pis'](tensor)
        mus = self.node_modules[self.node_type + '/decoder/proj_to_GMM_mus'](tensor)
        log_sigmas = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_sigmas'](tensor)
        # tanh keeps the correlation coefficient in (-1, 1).
        corrs = torch.tanh(self.node_modules[self.node_type + '/decoder/proj_to_GMM_corrs'](tensor))
        return log_pis, mus, log_sigmas, corrs

    def p_y_xz(self, mode, x, x_nr_t, y_r, n_s_t0, z_stacked, prediction_horizon,
               num_samples, num_components=1, gmm_mode=False):
        r"""
        .. math:: p_\psi(\mathbf{y}_i \mid \mathbf{x}_i, z)

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param x: Input / Condition tensor.
        :param x_nr_t: Joint state of node and robot (if robot is in scene).
        :param y_r: Encoded future tensor.
        :param n_s_t0: Standardized current state of the node.
        :param z_stacked: Stacked latent state. [num_samples_z * num_samples_gmm, bs, latent_state]
        :param prediction_horizon: Number of prediction timesteps.
        :param num_samples: Number of samples from the latent space.
        :param num_components: Number of GMM components.
        :param gmm_mode: If True: The mode of the GMM is sampled.
        :return: GMM2D. If mode is Predict, also samples from the GMM.
        """
        ph = prediction_horizon
        pred_dim = self.pred_state_length

        z = torch.reshape(z_stacked, (-1, self.latent.z_dim))
        zx = torch.cat([z, x.repeat(num_samples * num_components, 1)], dim=1)
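
        # x is tiled so that every latent sample and mixture component sees the
        # same condition vector; zx therefore has shape
        # [num_samples * num_components * bs, z_dim + x_size].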

        cell = self.node_modules[self.node_type + '/decoder/rnn_cell']
        initial_h_model = self.node_modules[self.node_type + '/decoder/initial_h']

        initial_state = initial_h_model(zx)

        log_pis, mus, log_sigmas, corrs, a_sample = [], [], [], [], []

        # Infer initial action state for node from current state
        a_0 = self.node_modules[self.node_type + '/decoder/state_action'](n_s_t0)

        state = initial_state
        if self.hyperparams['incl_robot_node']:
            input_ = torch.cat([zx,
                                a_0.repeat(num_samples * num_components, 1),
                                x_nr_t.repeat(num_samples * num_components, 1)], dim=1)
        else:
            input_ = torch.cat([zx, a_0.repeat(num_samples * num_components, 1)], dim=1)

        for j in range(ph):
            h_state = cell(input_, state)
            log_pi_t, mu_t, log_sigma_t, corr_t = self.project_to_GMM_params(h_state)

            gmm = GMM2D(log_pi_t, mu_t, log_sigma_t, corr_t)  # [k;bs, pred_dim]

            if mode == ModeKeys.PREDICT and gmm_mode:
                a_t = gmm.mode()
            else:
                a_t = gmm.rsample()

            if num_components > 1:
                if mode == ModeKeys.PREDICT:
                    log_pis.append(self.latent.p_dist.logits.repeat(num_samples, 1, 1))
                else:
                    log_pis.append(self.latent.q_dist.logits.repeat(num_samples, 1, 1))
            else:
                log_pis.append(
                    torch.ones_like(corr_t.reshape(num_samples, num_components, -1).permute(0, 2, 1).reshape(-1, 1))
                )

            mus.append(
                mu_t.reshape(
                    num_samples, num_components, -1, 2
                ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components)
            )
            log_sigmas.append(
                log_sigma_t.reshape(
                    num_samples, num_components, -1, 2
                ).permute(0, 2, 1, 3).reshape(-1, 2 * num_components))
            corrs.append(
                corr_t.reshape(
                    num_samples, num_components, -1
                ).permute(0, 2, 1).reshape(-1, num_components))

            if self.hyperparams['incl_robot_node']:
                dec_inputs = [zx, a_t, y_r[:, j].repeat(num_samples * num_components, 1)]
            else:
                dec_inputs = [zx, a_t]
            input_ = torch.cat(dec_inputs, dim=1)
            state = h_state

        log_pis = torch.stack(log_pis, dim=1)
        mus = torch.stack(mus, dim=1)
        log_sigmas = torch.stack(log_sigmas, dim=1)
        corrs = torch.stack(corrs, dim=1)

        a_dist = GMM2D(torch.reshape(log_pis, [num_samples, -1, ph, num_components]),
                       torch.reshape(mus, [num_samples, -1, ph, num_components * pred_dim]),
                       torch.reshape(log_sigmas, [num_samples, -1, ph, num_components * pred_dim]),
                       torch.reshape(corrs, [num_samples, -1, ph, num_components]))

        if self.hyperparams['dynamic'][self.node_type]['distribution']:
            y_dist = self.dynamic.integrate_distribution(a_dist, x)
        else:
            y_dist = a_dist

        if mode == ModeKeys.PREDICT:
            if gmm_mode:
                a_sample = a_dist.mode()
            else:
                a_sample = a_dist.rsample()
            sampled_future = self.dynamic.integrate_samples(a_sample, x)
            return y_dist, sampled_future
        else:
            return y_dist

    def encoder(self, mode, x, y_e, num_samples=None):
        """
        Encoder of the CVAE.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param x: Input / Condition tensor.
        :param y_e: Encoded future tensor.
        :param num_samples: Number of samples from the latent space during Prediction.
        :return: tuple(z, kl_obj)
            WHERE
            - z: Samples from the latent space.
            - kl_obj: KL divergence between q and p.
        """
        if mode == ModeKeys.TRAIN:
            sample_ct = self.hyperparams['k']
        elif mode == ModeKeys.EVAL:
            sample_ct = self.hyperparams['k_eval']
        elif mode == ModeKeys.PREDICT:
            sample_ct = num_samples
            if num_samples is None:
                raise ValueError("num_samples cannot be None with mode == PREDICT.")

        self.latent.q_dist = self.q_z_xy(mode, x, y_e)
        self.latent.p_dist = self.p_z_x(mode, x)

        z = self.latent.sample_q(sample_ct, mode)

        if mode == ModeKeys.TRAIN:
            kl_obj = self.latent.kl_q_p(self.log_writer, '%s' % str(self.node_type), self.curr_iter)
            if self.log_writer is not None:
                self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'kl'), kl_obj, self.curr_iter)
        else:
            kl_obj = None

        return z, kl_obj

    def decoder(self, mode, x, x_nr_t, y, y_r, n_s_t0, z, labels, prediction_horizon, num_samples):
        """
        Decoder of the CVAE.

        :param mode: Mode in which the model is operated. E.g. Train, Eval, Predict.
        :param x: Input / Condition tensor.
        :param x_nr_t: Joint state of node and robot (if robot is in scene).
        :param y: Future tensor.
        :param y_r: Encoded future tensor.
        :param n_s_t0: Standardized current state of the node.
        :param z: Stacked latent state.
        :param labels: Label tensor against which the log-probability is evaluated.
        :param prediction_horizon: Number of prediction timesteps.
        :param num_samples: Number of samples from the latent space.
        :return: Log probability of y over p.
        """
        num_components = self.hyperparams['N'] * self.hyperparams['K']
        y_dist = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z,
                             prediction_horizon, num_samples, num_components=num_components)
        log_p_yt_xz = torch.clamp(y_dist.log_prob(labels), max=self.hyperparams['log_p_yt_xz_max'])
        if self.hyperparams['log_histograms'] and self.log_writer is not None:
            self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_yt_xz'), log_p_yt_xz, self.curr_iter)

        log_p_y_xz = torch.sum(log_p_yt_xz, dim=2)
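
        # Summing over dim=2 (the time axis) turns per-timestep log-densities
        # into a log-likelihood for each sampled trajectory: [num_samples, bs].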

        return log_p_y_xz

    def train_loss(self,
                   inputs,
                   inputs_st,
                   first_history_indices,
                   labels,
                   labels_st,
                   neighbors,
                   neighbors_edge_value,
                   robot,
                   map,
                   prediction_horizon) -> torch.Tensor:
        """
        Calculates the training loss for a batch.

        :param inputs: Input tensor including the state for each agent over time [bs, t, state].
        :param inputs_st: Standardized input tensor.
        :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
        :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state].
        :param labels_st: Standardized label tensor.
        :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                            [[bs, t, neighbor state]]
        :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
        :param robot: Standardized robot state over time. [bs, t, robot_state]
        :param map: Tensor of Map information. [bs, channels, x, y]
        :param prediction_horizon: Number of prediction timesteps.
        :return: Scalar tensor -> nll loss
        """
        mode = ModeKeys.TRAIN

        x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                                     inputs=inputs,
                                                                     inputs_st=inputs_st,
                                                                     labels=labels,
                                                                     labels_st=labels_st,
                                                                     first_history_indices=first_history_indices,
                                                                     neighbors=neighbors,
                                                                     neighbors_edge_value=neighbors_edge_value,
                                                                     robot=robot,
                                                                     map=map)

        z, kl = self.encoder(mode, x, y_e)
        log_p_y_xz = self.decoder(mode, x, x_nr_t, y, y_r, n_s_t0, z,
                                  labels,  # Loss is calculated on unstandardized label
                                  prediction_horizon,
                                  self.hyperparams['k'])

        log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0)  # [nbs]
        log_likelihood = torch.mean(log_p_y_xz_mean)

        mutual_inf_q = mutual_inf_mc(self.latent.q_dist)
        mutual_inf_p = mutual_inf_mc(self.latent.p_dist)

        ELBO = log_likelihood - self.kl_weight * kl + 1. * mutual_inf_p
        loss = -ELBO
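
        # Objective (added note): this is the standard CVAE evidence lower bound,
        # E_q[log p(y | x, z)] - kl_weight * KL(q(z | x, y) || p(z | x)),
        # augmented with a Monte-Carlo mutual-information term on p(z | x) that
        # discourages the discrete latent from collapsing onto a single mode.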

        if self.hyperparams['log_histograms'] and self.log_writer is not None:
            self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_y_xz'),
                                          log_p_y_xz_mean,
                                          self.curr_iter)

        if self.log_writer is not None:
            self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_q'),
                                       mutual_inf_q,
                                       self.curr_iter)
            self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_p'),
                                       mutual_inf_p,
                                       self.curr_iter)
            self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'log_likelihood'),
                                       log_likelihood,
                                       self.curr_iter)
            self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'loss'),
                                       loss,
                                       self.curr_iter)
            if self.hyperparams['log_histograms']:
                self.latent.summarize_for_tensorboard(self.log_writer, str(self.node_type), self.curr_iter)
        return loss

    def eval_loss(self,
                  inputs,
                  inputs_st,
                  first_history_indices,
                  labels,
                  labels_st,
                  neighbors,
                  neighbors_edge_value,
                  robot,
                  map,
                  prediction_horizon) -> torch.Tensor:
        """
        Calculates the evaluation loss for a batch.

        :param inputs: Input tensor including the state for each agent over time [bs, t, state].
        :param inputs_st: Standardized input tensor.
        :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
        :param labels: Label tensor including the label output for each agent over time [bs, t, pred_state].
        :param labels_st: Standardized label tensor.
        :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                            [[bs, t, neighbor state]]
        :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
        :param robot: Standardized robot state over time. [bs, t, robot_state]
        :param map: Tensor of Map information. [bs, channels, x, y]
        :param prediction_horizon: Number of prediction timesteps.
        :return: Scalar tensor -> negative log-likelihood (NLL) estimate
        """
        mode = ModeKeys.EVAL

        x, x_nr_t, y_e, y_r, y, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                                     inputs=inputs,
                                                                     inputs_st=inputs_st,
                                                                     labels=labels,
                                                                     labels_st=labels_st,
                                                                     first_history_indices=first_history_indices,
                                                                     neighbors=neighbors,
                                                                     neighbors_edge_value=neighbors_edge_value,
                                                                     robot=robot,
                                                                     map=map)

        num_components = self.hyperparams['N'] * self.hyperparams['K']
        ### Importance sampled NLL estimate
        z, _ = self.encoder(mode, x, y_e)  # [k_eval, nbs, N*K]
        z = self.latent.sample_p(1, mode, full_dist=True)
        y_dist, _ = self.p_y_xz(ModeKeys.PREDICT, x, x_nr_t, y_r, n_s_t0, z,
                                prediction_horizon, num_samples=1, num_components=num_components)
        # We use unstandardized labels to compute the loss
        log_p_yt_xz = torch.clamp(y_dist.log_prob(labels), max=self.hyperparams['log_p_yt_xz_max'])
        log_p_y_xz = torch.sum(log_p_yt_xz, dim=2)
        log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0)  # [nbs]
        log_likelihood = torch.mean(log_p_y_xz_mean)
        nll = -log_likelihood
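
        # Note (added): despite the encoder call above, z is re-drawn from the
        # full latent mixture p(z | x) (full_dist=True), so the reported NLL is
        # evaluated under the complete mixture rather than a sampled
        # approximation.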

        return nll

    def predict(self,
                inputs,
                inputs_st,
                first_history_indices,
                neighbors,
                neighbors_edge_value,
                robot,
                map,
                prediction_horizon,
                num_samples,
                z_mode=False,
                gmm_mode=False,
                full_dist=True,
                all_z_sep=False):
        """
        Predicts the future of a batch of nodes.

        :param inputs: Input tensor including the state for each agent over time [bs, t, state].
        :param inputs_st: Standardized input tensor.
        :param first_history_indices: First timestep (index) in scene for which data is available for a node [bs]
        :param neighbors: Preprocessed dict (indexed by edge type) of list of neighbor states over time.
                            [[bs, t, neighbor state]]
        :param neighbors_edge_value: Preprocessed edge values for all neighbor nodes [[N]]
        :param robot: Standardized robot state over time. [bs, t, robot_state]
        :param map: Tensor of Map information. [bs, channels, x, y]
        :param prediction_horizon: Number of prediction timesteps.
        :param num_samples: Number of samples from the latent space.
        :param z_mode: If True: Select the most likely latent state.
        :param gmm_mode: If True: The mode of the GMM is sampled.
        :param full_dist: Samples all latent states and merges them into a GMM as output.
        :param all_z_sep: Samples each latent mode individually without merging them into a GMM.
        :return: Sampled future trajectories.
        """
        mode = ModeKeys.PREDICT

        x, x_nr_t, _, y_r, _, n_s_t0 = self.obtain_encoded_tensors(mode=mode,
                                                                   inputs=inputs,
                                                                   inputs_st=inputs_st,
                                                                   labels=None,
                                                                   labels_st=None,
                                                                   first_history_indices=first_history_indices,
                                                                   neighbors=neighbors,
                                                                   neighbors_edge_value=neighbors_edge_value,
                                                                   robot=robot,
                                                                   map=map)

        self.latent.p_dist = self.p_z_x(mode, x)
        z, num_samples, num_components = self.latent.sample_p(num_samples,
                                                              mode,
                                                              most_likely_z=z_mode,
                                                              full_dist=full_dist,
                                                              all_z_sep=all_z_sep)

        _, our_sampled_future = self.p_y_xz(mode, x, x_nr_t, y_r, n_s_t0, z,
                                            prediction_horizon,
                                            num_samples,
                                            num_components,
                                            gmm_mode)

        return our_sampled_future