# Trajectron-plus-plus/code/model/node_model.py

import warnings

import numpy as np
import torch
import torch.distributions as td
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from model.components import *
from model.model_utils import *


class MultimodalGenerativeCVAE(object):
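    """CVAE trajectory model for a single node type (e.g. VEHICLE).

    Encodes a node's history, its edges to neighboring nodes, and
    optionally the robot's future and a local map crop into a context
    vector x, infers a discrete latent z, and decodes a GMM over future
    trajectories.

    Rough usage sketch (argument values are illustrative, not defined in
    this file):

        model = MultimodalGenerativeCVAE(env, node_type, model_registrar,
                                         hyperparams, device, edge_types)
        model.set_annealing_params()
        loss = model.train_loss(inputs, inputs_st, first_history_indices,
                                labels, labels_st, scene,
                                node_scene_graph_batched, timestep,
                                timesteps_in_scene, prediction_horizon)
    """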
def __init__(self,
env,
node_type,
model_registrar,
hyperparams,
device,
edge_types,
log_writer=None):
self.env = env
self.node_type = node_type.name
self.model_registrar = model_registrar
self.log_writer = log_writer
self.device = device
self.edge_types = [edge_type for edge_type in edge_types if edge_type[0] is node_type]
self.curr_iter = 0
self.node_modules = dict()
self.hyperparams = hyperparams
self.min_hl = self.hyperparams['minimum_history_length']
self.max_hl = self.hyperparams['maximum_history_length']
self.ph = self.hyperparams['prediction_horizon']
self.state = self.hyperparams['state']
self.pred_state = self.hyperparams['pred_state'][node_type.name]
self.state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[node_type.name].values()])) * 2 # We have the relative and absolute state
self.robot_state_length = int(np.sum([len(entity_dims) for entity_dims in self.state['VEHICLE'].values()])) # TODO VEHICLE is hard coded for now
self.pred_state_length = int(np.sum([len(entity_dims) for entity_dims in self.pred_state.values()]))
edge_types_str = [env.edge_type_str(edge_type) for edge_type in self.edge_types]
self.create_graphical_model(edge_types_str)
def set_curr_iter(self, curr_iter):
self.curr_iter = curr_iter
def add_submodule(self, name, model_if_absent):
self.node_modules[name] = self.model_registrar.get_model(name, model_if_absent)
def clear_submodules(self):
self.node_modules.clear()
def create_graphical_model(self, edge_types):
self.clear_submodules()
############################
# Node History Encoder #
############################
self.add_submodule(self.node_type + '/node_history_encoder',
model_if_absent=nn.LSTM(input_size=self.state_length,
hidden_size=self.hyperparams['enc_rnn_dim_history'],
batch_first=True))
###########################
# Node Future Encoder #
###########################
        # We create this encoder here, but only factor it into the
        # computation graph when we are in training mode.
self.add_submodule(self.node_type + '/node_future_encoder',
model_if_absent=nn.LSTM(input_size=self.pred_state_length,
hidden_size=self.hyperparams['enc_rnn_dim_future'],
bidirectional=True,
batch_first=True))
        # These layers produce the initial hidden and cell states for the
        # node future encoder.
self.add_submodule(self.node_type + '/node_future_encoder/initial_h',
model_if_absent=nn.Linear(self.state_length,
self.hyperparams['enc_rnn_dim_future']))
self.add_submodule(self.node_type + '/node_future_encoder/initial_c',
model_if_absent=nn.Linear(self.state_length,
self.hyperparams['enc_rnn_dim_future']))
############################
# Robot Future Encoder #
############################
        # We create this encoder here, but only factor it into the
        # computation graph when the scene actually contains a robot.
if self.hyperparams['incl_robot_node']:
self.add_submodule('robot_future_encoder',
                               model_if_absent=nn.LSTM(input_size=self.robot_state_length,
hidden_size=self.hyperparams['enc_rnn_dim_future'],
bidirectional=True,
batch_first=True))
            # These layers produce the initial hidden and cell states for
            # the robot future encoder.
self.add_submodule('robot_future_encoder/initial_h',
model_if_absent=nn.Linear(self.robot_state_length,
self.hyperparams['enc_rnn_dim_future']))
self.add_submodule('robot_future_encoder/initial_c',
model_if_absent=nn.Linear(self.robot_state_length,
self.hyperparams['enc_rnn_dim_future']))
#####################
# Edge Encoders #
#####################
for edge_type in edge_types:
neighbor_state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[edge_type.split('-')[1]].values()]))
if self.hyperparams['edge_state_combine_method'] == 'pointnet':
self.add_submodule(edge_type + '/pointnet_encoder',
model_if_absent=nn.Sequential(
nn.Linear(self.state_length, 2 * self.state_length),
nn.ReLU(),
nn.Linear(2 * self.state_length, 2 * self.state_length),
nn.ReLU()))
edge_encoder_input_size = 2 * self.state_length + self.state_length
elif self.hyperparams['edge_state_combine_method'] == 'attention':
self.add_submodule(self.node_type + '/edge_attention_combine',
model_if_absent=TemporallyBatchedAdditiveAttention(
encoder_hidden_state_dim=self.state_length,
decoder_hidden_state_dim=self.state_length))
edge_encoder_input_size = self.state_length + neighbor_state_length
else:
edge_encoder_input_size = self.state_length + neighbor_state_length
self.add_submodule(edge_type + '/edge_encoder',
model_if_absent=nn.LSTM(input_size=edge_encoder_input_size,
hidden_size=self.hyperparams['enc_rnn_dim_edge'],
batch_first=True))
##############################
# Edge Influence Encoder #
##############################
# NOTE: The edge influence encoding happens during calls
# to forward or incremental_forward, so we don't create
# a model for it here for the max and sum variants.
if self.hyperparams['edge_influence_combine_method'] == 'bi-rnn':
self.add_submodule(self.node_type + '/edge_influence_encoder',
model_if_absent=nn.LSTM(input_size=self.hyperparams['enc_rnn_dim_edge'],
hidden_size=self.hyperparams['enc_rnn_dim_edge_influence'],
bidirectional=True,
batch_first=True))
# Four times because we're trying to mimic a bi-directional
# LSTM's output (which, here, is c and h from both ends).
self.eie_output_dims = 4 * self.hyperparams['enc_rnn_dim_edge_influence']
elif self.hyperparams['edge_influence_combine_method'] == 'attention':
            # We chose additive attention following https://arxiv.org/pdf/1703.03906.pdf
# We calculate an attention context vector using the encoded edges as the "encoder" (that we attend _over_)
# and the node history encoder representation as the "decoder state" (that we attend _on_).
self.add_submodule(self.node_type + '/edge_influence_encoder',
model_if_absent=AdditiveAttention(
encoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_edge_influence'],
decoder_hidden_state_dim=self.hyperparams['enc_rnn_dim_history']))
self.eie_output_dims = self.hyperparams['enc_rnn_dim_edge_influence']
###################
# Map Encoder #
###################
if self.hyperparams['use_map_encoding']:
self.add_submodule(self.node_type + '/map_encoder',
model_if_absent=CNNMapEncoder(input_size=self.hyperparams['map_context'],
hidden_size=self.hyperparams['map_enc_hidden_size'],
output_size=self.hyperparams['map_enc_output_size']))
################################
# Discrete Latent Variable #
################################
self.latent = DiscreteLatent(self.hyperparams, self.device)
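        # z is composed of N K-way categorical variables, so its one-hot
        # encoding has dimension N * K (the z_size computed below).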
######################################################################
# Various Fully-Connected Layers from Encoder to Latent Variable #
######################################################################
        # Edge Influence Encoder + Node History Encoder
x_size = self.eie_output_dims + self.hyperparams['enc_rnn_dim_history']
if self.hyperparams['incl_robot_node']:
# Future Conditional Encoder
x_size += 4 * self.hyperparams['enc_rnn_dim_future']
if self.hyperparams['use_map_encoding']:
# Map Encoder
x_size += self.hyperparams['map_enc_output_size']
z_size = self.hyperparams['N'] * self.hyperparams['K']
if self.hyperparams['p_z_x_MLP_dims'] is not None:
self.add_submodule(self.node_type + '/p_z_x',
model_if_absent=nn.Linear(x_size, self.hyperparams['p_z_x_MLP_dims']))
hx_size = self.hyperparams['p_z_x_MLP_dims']
else:
hx_size = x_size
self.add_submodule(self.node_type + '/hx_to_z',
model_if_absent=nn.Linear(hx_size, self.latent.z_dim))
if self.hyperparams['q_z_xy_MLP_dims'] is not None:
self.add_submodule(self.node_type + '/q_z_xy',
# Node Future Encoder
model_if_absent=nn.Linear(x_size + 4 * self.hyperparams['enc_rnn_dim_future'],
self.hyperparams['q_z_xy_MLP_dims']))
hxy_size = self.hyperparams['q_z_xy_MLP_dims']
else:
# Node Future Encoder
hxy_size = x_size + 4 * self.hyperparams['enc_rnn_dim_future']
self.add_submodule(self.node_type + '/hxy_to_z',
model_if_absent=nn.Linear(hxy_size, self.latent.z_dim))
        ###################
        #   Decoder GRU   #
        ###################
if self.hyperparams['incl_robot_node']:
decoder_input_dims = self.pred_state_length + self.robot_state_length + z_size + x_size
else:
decoder_input_dims = self.pred_state_length + z_size + x_size
self.add_submodule(self.node_type + '/decoder/lstm_cell',
model_if_absent=nn.GRUCell(decoder_input_dims, self.hyperparams['dec_rnn_dim']))
self.add_submodule(self.node_type + '/decoder/initial_h',
model_if_absent=nn.Linear(z_size + x_size, self.hyperparams['dec_rnn_dim']))
self.add_submodule(self.node_type + '/decoder/initial_c',
model_if_absent=nn.Linear(z_size + x_size, self.hyperparams['dec_rnn_dim']))
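        # NOTE: despite the 'lstm_cell' registry name, the decoder cell is a
        # GRU; p_y_xz only consumes initial_h, so initial_c above is unused.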
###################
# Decoder GMM #
###################
self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_pis',
model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
self.hyperparams['GMM_components']))
self.add_submodule(self.node_type + '/decoder/proj_to_GMM_mus',
model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
self.hyperparams['GMM_components'] * self.pred_state_length))
self.add_submodule(self.node_type + '/decoder/proj_to_GMM_log_sigmas',
model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
self.hyperparams['GMM_components'] * self.pred_state_length))
self.add_submodule(self.node_type + '/decoder/proj_to_GMM_corrs',
model_if_absent=nn.Linear(self.hyperparams['dec_rnn_dim'],
self.hyperparams['GMM_components']))
for name, module in self.node_modules.items():
module.to(self.device)
def create_new_scheduler(self, name, annealer, annealer_kws, creation_condition=True):
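        """Set up an annealed value by piggybacking on PyTorch's LR
        machinery: the value is stored as the learning rate of a dummy
        optimizer and refreshed on every step_annealers() call through a
        CustomLR scheduler wrapping `annealer`."""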
value_scheduler = None
rsetattr(self, name + '_scheduler', value_scheduler)
if creation_condition:
annealer_kws['device'] = self.device
value_annealer = annealer(annealer_kws)
rsetattr(self, name + '_annealer', value_annealer)
# This is the value that we'll update on each call of
# step_annealers().
with warnings.catch_warnings():
warnings.simplefilter("ignore")
rsetattr(self, name, torch.tensor(value_annealer(0), device=self.device))
dummy_optimizer = optim.Optimizer([rgetattr(self, name)],
{'lr': torch.tensor(value_annealer(0), device=self.device)})
rsetattr(self, name + '_optimizer', dummy_optimizer)
value_scheduler = CustomLR(dummy_optimizer,
value_annealer)
rsetattr(self, name + '_scheduler', value_scheduler)
self.schedulers.append(value_scheduler)
self.annealed_vars.append(name)
def set_annealing_params(self):
self.schedulers = list()
self.annealed_vars = list()
self.create_new_scheduler(name='kl_weight',
annealer=sigmoid_anneal,
annealer_kws={
'start': self.hyperparams['kl_weight_start'],
'finish': self.hyperparams['kl_weight'],
'center_step': self.hyperparams['kl_crossover'],
'steps_lo_to_hi': self.hyperparams['kl_crossover'] / self.hyperparams[
'kl_sigmoid_divisor']
},
creation_condition=((np.abs(self.hyperparams['alpha'] - 1.0) < 1e-3)
and (not self.hyperparams['use_iwae'])))
self.create_new_scheduler(name='dec_sample_model_prob',
annealer=sigmoid_anneal,
annealer_kws={
'start': self.hyperparams['dec_sample_model_prob_start'],
'finish': self.hyperparams['dec_sample_model_prob_final'],
'center_step': self.hyperparams['dec_sample_model_prob_crossover'],
'steps_lo_to_hi': self.hyperparams['dec_sample_model_prob_crossover'] /
self.hyperparams['dec_sample_model_prob_divisor']
},
creation_condition=self.hyperparams['sample_model_during_dec'])
self.create_new_scheduler(name='latent.temp',
annealer=exp_anneal,
annealer_kws={
'start': self.hyperparams['tau_init'],
'finish': self.hyperparams['tau_final'],
'rate': self.hyperparams['tau_decay_rate']
})
self.create_new_scheduler(name='latent.z_logit_clip',
annealer=sigmoid_anneal,
annealer_kws={
'start': self.hyperparams['z_logit_clip_start'],
'finish': self.hyperparams['z_logit_clip_final'],
'center_step': self.hyperparams['z_logit_clip_crossover'],
'steps_lo_to_hi': self.hyperparams['z_logit_clip_crossover'] / self.hyperparams[
'z_logit_clip_divisor']
},
creation_condition=self.hyperparams['use_z_logit_clipping'])
self.create_new_scheduler(name='warmup_dropout_keep',
annealer=sigmoid_anneal,
annealer_kws={
'start': self.hyperparams['inf_warmup_start'],
'finish': self.hyperparams['inf_warmup'],
'center_step': self.hyperparams['inf_warmup_crossover'],
'steps_lo_to_hi': self.hyperparams['inf_warmup_crossover'] / self.hyperparams[
'inf_warmup_sigmoid_divisor']
})
def step_annealers(self):
        # Step every annealed parameter's scheduler and refresh its value.
        for annealed_var in self.annealed_vars:
if rgetattr(self, annealed_var + '_scheduler') is not None:
# First we step the scheduler.
rgetattr(self, annealed_var + '_scheduler').step()
# Then we set the annealed vars' value.
rsetattr(self, annealed_var, rgetattr(self, annealed_var + '_optimizer').param_groups[0]['lr'])
self.summarize_annealers()
def summarize_annealers(self):
if self.log_writer is not None:
for annealed_var in self.annealed_vars:
if rgetattr(self, annealed_var) is not None:
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), annealed_var.replace('.', '/')),
rgetattr(self, annealed_var), self.curr_iter)
def obtain_encoded_tensor_dict(self,
mode,
timestep,
timesteps_in_scene,
inputs,
inputs_st,
labels,
labels_st,
first_history_indices,
scene,
node_scene_graph_batched):
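        """Assemble all encoder outputs needed downstream into one dict:
        node history/present/future encodings, per-edge-type encodings and
        their combined influence, optional robot-future and map encodings,
        and the concatenated context vector under key "x"."""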
        tensor_dict = dict()
batch_size = inputs.shape[0]
#########################################
# Provide basic information to encoders #
#########################################
node_traj = inputs
node_history = inputs[:, :timestep + 1]
node_present_state = inputs[:, timestep]
node_pos = inputs[:, timestep, 0:2]
node_vel = inputs[:, timestep, 2:4]
node_traj_st = inputs_st
node_history_st = inputs_st[:, :timestep + 1]
node_present_state_st = inputs_st[:, timestep]
node_pos_st = inputs_st[:, timestep, 0:2]
node_vel_st = inputs_st[:, timestep, 2:4]
if self.hyperparams['incl_robot_node'] and scene.robot is not None:
robot_traj_list = []
for timestep_s in timesteps_in_scene:
timestep_range = np.array([timestep_s - self.max_hl, timestep_s + node_traj.shape[1] - self.max_hl - 1])
robot_traj_list.append(scene.robot.get(timestep_range, self.state[scene.robot.type.name]))
robot_traj_np = np.array(robot_traj_list)
# Make Robot State relative to node
_, std = self.env.get_standardize_params(self.state[scene.robot.type.name], node_type=scene.robot.type)
std[0:2] = 40
rel_state = np.zeros_like(robot_traj_np)
rel_state[..., :6] = node_traj[..., :6].cpu()
robot_traj_np_st = self.env.standardize(robot_traj_np,
self.state[scene.robot.type.name],
node_type=scene.robot.type,
mean=rel_state,
std=std)
robot_traj_st = torch.tensor(robot_traj_np_st).float().to(self.device)
robot_present_state_st = robot_traj_st[:, timestep]
robot_future_st = robot_traj_st[:, timestep+1:]
tensor_dict['robot_present'] = robot_present_state_st
tensor_dict['robot_future'] = robot_future_st
##################
# Encode History #
##################
tensor_dict['node_history_encoded'] = self.encode_node_history(mode,
node_history_st,
first_history_indices,
timestep)
##################
# Encode Present #
##################
tensor_dict['node_present'] = node_present_state_st # [bs, state_dim]
##################
# Encode Future #
##################
if mode != ModeKeys.PREDICT:
tensor_dict['node_future'] = labels_st[:, timestep + 1:timestep + self.ph + 1] # [bs, ph, state_dim]
#######################################
# Encode Joint Present (Robot + Node) #
#######################################
if self.warmup_dropout_keep < 0.5:
if self.hyperparams['incl_robot_node'] and scene.robot is not None:
tensor_dict['joint_present'] = torch.zeros_like(torch.cat([robot_present_state_st,
labels_st[:, timestep]], dim=1))
else:
tensor_dict['joint_present'] = torch.zeros_like(labels_st[:, timestep])
else:
if self.hyperparams['incl_robot_node'] and scene.robot is not None:
tensor_dict['joint_present'] = torch.cat([robot_present_state_st, labels_st[:, timestep]], dim=1)
else:
tensor_dict['joint_present'] = labels_st[:, timestep]
##############################
# Encode Node Edges per Type #
##############################
tensor_dict["node_edges_encoded"] = list()
for edge_type in self.edge_types:
connected_nodes_batched = list()
edge_masks_batched = list()
for i, (node, scene_graph) in enumerate(node_scene_graph_batched):
# We get all nodes which are connected to the current node for the current timestep
connected_nodes_batched.append(scene_graph.get_neighbors(node, edge_type[1]))
if self.hyperparams['dynamic_edges'] == 'yes':
# We get the edge masks for the current node at the current timestep
edge_masks_for_node = scene_graph.get_edge_scaling(node)
edge_masks_batched.append(torch.tensor(edge_masks_for_node).float().to(self.device))
# Encode edges for given edge type
encoded_edges_type = self.encode_edge(mode,
node_history,
node_history_st,
edge_type,
connected_nodes_batched,
edge_masks_batched,
first_history_indices,
timestep,
timesteps_in_scene,
scene)
tensor_dict["node_edges_encoded"].append(encoded_edges_type) # List of [bs/nbs, enc_rnn_dim]
#####################
# Encode Node Edges #
#####################
tensor_dict["total_edge_influence"] = self.encode_total_edge_influence(mode,
tensor_dict["node_edges_encoded"],
tensor_dict["node_history_encoded"],
batch_size) # [bs/nbs, 4*enc_rnn_dim]
##############
# Encode Map #
##############
if mode == ModeKeys.TRAIN:
            # Drawn outside the map branch below so RNG consumption (and
            # therefore seeding) stays consistent whether or not maps are used.
            rand_heading = (2 * np.random.rand(node_present_state.shape[0]) - 1) * 5 * np.pi / 180
else:
rand_heading = 0.
if self.hyperparams['use_map_encoding']:
heading = node_present_state.cpu().numpy()[:, -1] + rand_heading
node_pos_cpu = node_pos.cpu().numpy()
if self.node_type == 'VEHICLE':
node_pos_cpu = node_pos_cpu + 20 * np.array([np.cos(heading), np.sin(heading)]).T
cropped_maps_np = get_cropped_maps_heading_exact(world_pts=node_pos_cpu,
map=scene.map[self.node_type],
context_size=self.hyperparams['map_context'],
heading=heading)
cropped_maps_np = np.swapaxes(cropped_maps_np, -1, 1)
cropped_maps = torch.from_numpy(cropped_maps_np).to(self.device)
del cropped_maps_np
encoded_map = self.node_modules[self.node_type + '/map_encoder'](cropped_maps)
encoded_map = F.dropout(encoded_map, 0.5, training=(mode == ModeKeys.TRAIN))
tensor_dict["encoded_maps"] = encoded_map
if self.log_writer is not None and mode != ModeKeys.PREDICT:
context_size = self.hyperparams['map_context']
self.log_writer.add_images('%s/cropped_maps' % str(self.node_type),
cropped_maps,
self.curr_iter)
img_pts = scene.map[self.node_type].to_map_points(node_pos_cpu)
box_arr = np.empty((img_pts.shape[0], 4))
box_arr[:, 0] = img_pts[:, 0] - context_size // 2
box_arr[:, 1] = img_pts[:, 1] - context_size // 2
box_arr[:, 2] = img_pts[:, 0] + context_size // 2
box_arr[:, 3] = img_pts[:, 1] + context_size // 2
self.log_writer.add_image_with_boxes('%s/cropped_locs' % str(self.node_type),
np.swapaxes(scene.map[self.node_type].fdata, -1, 0).astype(float),
box_arr,
self.curr_iter)
######################################
# Concatenate Encoder Outputs into x #
######################################
concat_list = list()
if self.hyperparams['use_map_encoding']:
concat_list.append(tensor_dict["encoded_maps"]) # [bs/nbs, map_enc_output_size]
# Every node has an edge-influence encoder (which could just be zero).
if self.warmup_dropout_keep < 0.5:
concat_list.append(torch.zeros_like(tensor_dict["total_edge_influence"]))
else:
concat_list.append(tensor_dict["total_edge_influence"]) # [bs/nbs, 4*enc_rnn_dim]
# Every node has a history encoder.
if self.warmup_dropout_keep < 0.5:
concat_list.append(torch.zeros_like(tensor_dict["node_history_encoded"]))
else:
concat_list.append(tensor_dict["node_history_encoded"]) # [bs/nbs, enc_rnn_dim_history]
if self.hyperparams['incl_robot_node'] and scene.robot is not None:
tensor_dict[scene.robot.type.name + "_robot_future_encoder"] = self.encode_robot_future(
tensor_dict['robot_present'],
tensor_dict['robot_future'],
mode,
scene.robot.type.name + '_robot')
# [bs/nbs, 4*enc_rnn_dim_future]
concat_list.append(tensor_dict[scene.robot.type.name + "_robot_future_encoder"])
elif self.hyperparams['incl_robot_node']:
# Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
concat_list.append(
torch.zeros([batch_size, 4 * self.hyperparams['enc_rnn_dim_future']], device=self.device))
tensor_dict["x"] = torch.cat(concat_list, dim=1)
if mode == ModeKeys.TRAIN or mode == ModeKeys.EVAL:
tensor_dict[self.node_type + "_future_encoder"] = self.encode_node_future(tensor_dict['node_present'],
tensor_dict['node_future'],
mode,
self.node_type)
return tensor_dict
def encode_node_history(self, mode, node_traj, first_history_indices, timestep):
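        """Run the node-history LSTM over variable-length histories and
        return each sequence's output at the current timestep."""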
outputs, _ = run_lstm_on_variable_length_seqs(self.node_modules[self.node_type + '/node_history_encoder'],
node_traj,
torch.ones_like(first_history_indices) * timestep,
first_history_indices,
self.hyperparams[
'maximum_history_length'] + 1) # history + current
outputs = F.dropout(outputs,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN)) # [bs, max_time, enc_rnn_dim]
last_index_per_sequence = timestep - first_history_indices
return outputs[torch.arange(first_history_indices.shape[0]), last_index_per_sequence]
def encode_edge(self,
mode,
node_history,
node_history_st,
edge_type,
connected_nodes,
edge_masks,
first_history_indices,
timestep,
timesteps_in_scene,
scene):
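        """Encode one edge type: standardize each neighbor's trajectory
        relative to this node, combine neighbors via the configured
        sum/max/mean reduction, run the joint neighbor+node history through
        the edge LSTM, and (for dynamic edges) scale by the edge masks."""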
max_hl = self.hyperparams['maximum_history_length']
edge_states_list = list() # list of [#of neighbors, max_hl, state_dim]
for i, timestep_in_scene in enumerate(timesteps_in_scene): # Get neighbors for timestep in batch
neighbor_states = list()
for node in connected_nodes[i]:
neighbor_state_np = node.get(np.array([timestep_in_scene - max_hl, timestep_in_scene]),
self.state[node.type.name],
padding=0.0)
# Make State relative to node
_, std = self.env.get_standardize_params(self.state[node.type.name], node_type=node.type)
std[0:2] = self.env.attention_radius[edge_type]
rel_state = np.zeros_like(neighbor_state_np)
rel_state[:, :6] = node_history[i, -1, :6].cpu()
neighbor_state_np_st = self.env.standardize(neighbor_state_np,
self.state[node.type.name],
node_type=node.type,
mean=rel_state,
std=std)
neighbor_state = torch.tensor(neighbor_state_np_st).float().to(self.device)
neighbor_states.append(neighbor_state)
if len(neighbor_states) == 0: # There are no neighbors for edge type # TODO necessary?
neighbor_state_length = int(np.sum([len(entity_dims) for entity_dims in self.state[edge_type[1].name].values()]))
edge_states_list.append(torch.zeros((1, max_hl + 1, neighbor_state_length), device=self.device))
else:
edge_states_list.append(torch.stack(neighbor_states, dim=0))
if self.hyperparams['edge_state_combine_method'] == 'sum':
# Used in Structural-RNN to combine edges as well.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.sum(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
op_applied_edge_mask_list.append(torch.clamp(torch.sum(edge_mask, dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'max':
# Used in NLP, e.g. max over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
                op_applied_edge_states_list.append(torch.max(neighbors_state, dim=0).values)
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
                    op_applied_edge_mask_list.append(torch.clamp(torch.max(edge_mask, dim=0, keepdim=True).values, max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
elif self.hyperparams['edge_state_combine_method'] == 'mean':
# Used in NLP, e.g. mean over word embeddings in a sentence.
op_applied_edge_states_list = list()
for neighbors_state in edge_states_list:
op_applied_edge_states_list.append(torch.mean(neighbors_state, dim=0))
combined_neighbors = torch.stack(op_applied_edge_states_list, dim=0)
if self.hyperparams['dynamic_edges'] == 'yes':
# Should now be (bs, time, 1)
op_applied_edge_mask_list = list()
for edge_mask in edge_masks:
op_applied_edge_mask_list.append(torch.clamp(torch.mean(edge_mask, dim=0, keepdim=True), max=1.))
combined_edge_masks = torch.stack(op_applied_edge_mask_list, dim=0)
joint_history = torch.cat([combined_neighbors, node_history_st], dim=-1)
outputs, _ = run_lstm_on_variable_length_seqs(self.node_modules[self.env.edge_type_str(edge_type) + '/edge_encoder'],
joint_history,
torch.ones_like(first_history_indices) * timestep,
first_history_indices,
self.hyperparams[
'maximum_history_length'] + 1) # Add prediction timestep
outputs = F.dropout(outputs,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN)) # [bs, max_time, enc_rnn_dim]
last_index_per_sequence = timestep - first_history_indices
ret = outputs[torch.arange(last_index_per_sequence.shape[0]), last_index_per_sequence]
if self.hyperparams['dynamic_edges'] == 'yes':
return ret * combined_edge_masks
else:
return ret
def encode_total_edge_influence(self, mode, encoded_edges, node_history_encoder, batch_size):
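        """Combine per-edge-type encodings into a single edge-influence
        vector via sum, mean, max, a bi-directional RNN, or additive
        attention over the edge encodings keyed on the node history."""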
if self.hyperparams['edge_influence_combine_method'] == 'sum':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
combined_edges = torch.sum(stacked_encoded_edges, dim=0)
elif self.hyperparams['edge_influence_combine_method'] == 'mean':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
combined_edges = torch.mean(stacked_encoded_edges, dim=0)
elif self.hyperparams['edge_influence_combine_method'] == 'max':
stacked_encoded_edges = torch.stack(encoded_edges, dim=0)
            combined_edges = torch.max(stacked_encoded_edges, dim=0).values
elif self.hyperparams['edge_influence_combine_method'] == 'bi-rnn':
if len(encoded_edges) == 0:
combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)
else:
# axis=1 because then we get size [batch_size, max_time, depth]
encoded_edges = torch.stack(encoded_edges, dim=1)
_, state = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges)
combined_edges = unpack_RNN_state(state)
combined_edges = F.dropout(combined_edges,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
elif self.hyperparams['edge_influence_combine_method'] == 'attention':
# Used in Social Attention (https://arxiv.org/abs/1710.04689)
if len(encoded_edges) == 0:
combined_edges = torch.zeros((batch_size, self.eie_output_dims), device=self.device)
else:
# axis=1 because then we get size [batch_size, max_time, depth]
encoded_edges = torch.stack(encoded_edges, dim=1)
combined_edges, _ = self.node_modules[self.node_type + '/edge_influence_encoder'](encoded_edges,
node_history_encoder)
combined_edges = F.dropout(combined_edges,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
return combined_edges
def encode_node_future(self, node_present, node_future, mode, scope):
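        """Encode the ground-truth future with a bi-directional LSTM whose
        forward initial states come from the present state; the backward
        initial states are zeroed."""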
initial_h_model = self.node_modules[self.node_type + '/node_future_encoder/initial_h']
initial_c_model = self.node_modules[self.node_type + '/node_future_encoder/initial_c']
# Here we're initializing the forward hidden states,
# but zeroing the backward ones.
initial_h = initial_h_model(node_present)
initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0)
initial_c = initial_c_model(node_present)
initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0)
initial_state = (initial_h, initial_c)
_, state = self.node_modules[self.node_type + '/node_future_encoder'](node_future, initial_state)
state = unpack_RNN_state(state)
state = F.dropout(state,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
return state
def encode_robot_future(self, robot_present, robot_future, mode, scope):
initial_h_model = self.node_modules['robot_future_encoder/initial_h']
initial_c_model = self.node_modules['robot_future_encoder/initial_c']
# Here we're initializing the forward hidden states,
# but zeroing the backward ones.
initial_h = initial_h_model(robot_present)
initial_h = torch.stack([initial_h, torch.zeros_like(initial_h, device=self.device)], dim=0)
initial_c = initial_c_model(robot_present)
initial_c = torch.stack([initial_c, torch.zeros_like(initial_c, device=self.device)], dim=0)
initial_state = (initial_h, initial_c)
_, state = self.node_modules['robot_future_encoder'](robot_future, initial_state)
state = unpack_RNN_state(state)
state = F.dropout(state,
p=1. - self.hyperparams['rnn_kwargs']['dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
return state
def q_z_xy(self, x, y, mode):
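        """Posterior q(z | x, y): an optional MLP over [x; y] followed by a
        projection to the latent logits."""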
xy = torch.cat([x, y], dim=1)
if self.hyperparams['q_z_xy_MLP_dims'] is not None:
dense = self.node_modules[self.node_type + '/q_z_xy']
h = F.dropout(F.relu(dense(xy)),
p=1. - self.hyperparams['MLP_dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
else:
h = xy
to_latent = self.node_modules[self.node_type + '/hxy_to_z']
return self.latent.dist_from_h(to_latent(h), mode)
def p_z_x(self, x, mode):
if self.hyperparams['p_z_x_MLP_dims'] is not None:
dense = self.node_modules[self.node_type + '/p_z_x']
h = F.dropout(F.relu(dense(x)),
p=1. - self.hyperparams['MLP_dropout_keep_prob'],
training=(mode == ModeKeys.TRAIN))
else:
h = x
to_latent = self.node_modules[self.node_type + '/hx_to_z']
return self.latent.dist_from_h(to_latent(h), mode)
def project_to_GMM_params(self, tensor):
log_pis = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_pis'](tensor)
mus = self.node_modules[self.node_type + '/decoder/proj_to_GMM_mus'](tensor)
log_sigmas = self.node_modules[self.node_type + '/decoder/proj_to_GMM_log_sigmas'](tensor)
corrs = torch.tanh(self.node_modules[self.node_type + '/decoder/proj_to_GMM_corrs'](tensor))
return log_pis, mus, log_sigmas, corrs
def p_y_xz(self, x, z_stacked, tensor_dict, mode,
num_predicted_timesteps, num_samples_z, num_samples_gmm=1, most_likely_gmm=False):
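        """Decode p(y | x, z): unroll the decoder GRU for `ph` steps,
        projecting each hidden state to GMM parameters (log_pis, mus,
        log_sigmas, corrs). In PREDICT mode also returns sampled futures of
        shape [num_samples_z, num_samples_gmm, bs, ph, pred_dim]."""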
ph = num_predicted_timesteps
our_future = "node_future"
robot_future = "robot_future"
k = num_samples_z * num_samples_gmm
GMM_c, pred_dim = self.hyperparams['GMM_components'], self.pred_state_length
z = torch.reshape(z_stacked, (-1, self.latent.z_dim))
zx = torch.cat([z, x.repeat(k, 1)], dim=1)
cell = self.node_modules[self.node_type + '/decoder/lstm_cell']
initial_h_model = self.node_modules[self.node_type + '/decoder/initial_h']
initial_state = initial_h_model(zx)
log_pis, mus, log_sigmas, corrs = [], [], [], []
if mode in [ModeKeys.TRAIN, ModeKeys.EVAL]:
state = initial_state
if self.hyperparams['sample_model_during_dec'] and mode == ModeKeys.TRAIN:
input_ = torch.cat([zx, tensor_dict['joint_present'].repeat(k, 1)], dim=1)
for j in range(ph):
h_state = cell(input_, state)
log_pi_t, mu_t, log_sigma_t, corr_t = self.project_to_GMM_params(h_state)
y_t = GMM2D(log_pi_t, mu_t, log_sigma_t, corr_t, self.pred_state_length, self.device,
self.hyperparams['log_sigma_min'],
self.hyperparams['log_sigma_max']).sample() # [k;bs, pred_dim]
# This is where we pick our output y_t or the true output
# our_future to pass into the next cell (we do this with
# probability self.dec_sample_model_prob and is only done
# during training).
mask = td.Bernoulli(probs=self.dec_sample_model_prob).sample((y_t.size()[0], 1))
y_t = mask * y_t + (1 - mask) * (tensor_dict[our_future][:, j, :].repeat(k, 1))
log_pis.append(log_pi_t)
mus.append(mu_t)
log_sigmas.append(log_sigma_t)
corrs.append(corr_t)
if self.hyperparams['incl_robot_node']:
dec_inputs = torch.cat([tensor_dict[robot_future][:, j, :].repeat(k, 1), y_t], dim=1)
else:
dec_inputs = y_t
input_ = torch.cat([zx, dec_inputs], dim=1)
state = h_state
log_pis = torch.stack(log_pis, dim=1)
mus = torch.stack(mus, dim=1)
log_sigmas = torch.stack(log_sigmas, dim=1)
corrs = torch.stack(corrs, dim=1)
else:
zx_with_time_dim = zx.unsqueeze(dim=1) # [k;bs/nbs, 1, N*K + 2*enc_rnn_dim]
zx_time_tiled = zx_with_time_dim.repeat(1, ph, 1)
if self.hyperparams['incl_robot_node']:
dec_inputs = torch.cat([
tensor_dict["joint_present"].unsqueeze(dim=1),
torch.cat([tensor_dict[robot_future][:, :ph - 1, :], tensor_dict[our_future][:, :ph - 1, :]],
dim=2)
], dim=1)
else:
dec_inputs = torch.cat([
tensor_dict["joint_present"].unsqueeze(dim=1),
tensor_dict[our_future][:, :ph - 1, :]
], dim=1)
                # `inputs` is identical across iterations, so build it once.
                inputs = torch.cat([zx_time_tiled, dec_inputs.repeat(k, 1, 1)], dim=2)
                outputs = list()
                for j in range(ph):
                    h_state = cell(inputs[:, j, :], state)
outputs.append(h_state)
state = h_state
outputs = torch.stack(outputs, dim=1)
log_pis, mus, log_sigmas, corrs = self.project_to_GMM_params(outputs)
if self.hyperparams['log_histograms'] and self.log_writer is not None:
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'GMM_log_pis'), log_pis, self.curr_iter)
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'GMM_mus'), mus, self.curr_iter)
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'GMM_log_sigmas'), log_sigmas,
self.curr_iter)
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'GMM_corrs'), corrs, self.curr_iter)
elif mode == ModeKeys.PREDICT:
input_ = torch.cat([zx, tensor_dict["joint_present"].repeat(k, 1)], dim=1)
state = initial_state
log_pis, mus, log_sigmas, corrs, y = [], [], [], [], []
for j in range(ph):
h_state = cell(input_, state)
log_pi_t, mu_t, log_sigma_t, corr_t = self.project_to_GMM_params(h_state)
gmm = GMM2D(log_pi_t, mu_t, log_sigma_t, corr_t, self.pred_state_length, self.device,
self.hyperparams['log_sigma_min'],
self.hyperparams['log_sigma_max']) # [k;bs, pred_dim]
if most_likely_gmm:
y_t_list = []
for i in range(gmm.mus.shape[0]):
gmm_i = GMM2D(log_pi_t[i], mu_t[i], log_sigma_t[i], corr_t[i], self.pred_state_length,
self.device,
self.hyperparams['log_sigma_min'],
self.hyperparams['log_sigma_max']) # [k;bs, pred_dim]
x_min = gmm.mus[i, ..., 0].min()
x_max = gmm.mus[i, ..., 0].max()
y_min = gmm.mus[i, ..., 1].min()
y_max = gmm.mus[i, ..., 1].max()
x_min = x_min - 0.5 * torch.abs(x_min)
x_max = x_max + 0.5 * torch.abs(x_max)
y_min = y_min - 0.5 * torch.abs(y_min)
y_max = y_max + 0.5 * torch.abs(y_max)
search_grid = torch.stack(torch.meshgrid([torch.arange(x_min, x_max, 0.01),
torch.arange(y_min, y_max, 0.01)]), dim=2
).view(1, -1, 2).float().to(
self.device)
ll_score = gmm_i.log_prob(search_grid).squeeze()
y_t_list.append(search_grid[0, torch.argmax(ll_score, dim=0)])
y_t = torch.stack(y_t_list, dim=0)
else:
y_t = gmm.sample()
log_pis.append(log_pi_t)
mus.append(mu_t)
log_sigmas.append(log_sigma_t)
corrs.append(corr_t)
y.append(y_t)
if self.hyperparams['incl_robot_node']:
dec_inputs = torch.cat([tensor_dict[robot_future][:, j, :].repeat(k, 1), y_t], dim=1)
else:
dec_inputs = y_t
                input_ = torch.cat([zx, dec_inputs], dim=1)
state = h_state
log_pis = torch.stack(log_pis, dim=1)
mus = torch.stack(mus, dim=1)
log_sigmas = torch.stack(log_sigmas, dim=1)
corrs = torch.stack(corrs, dim=1)
sampled_future = torch.reshape(torch.stack(y, dim=1), (num_samples_z, num_samples_gmm, -1, ph, pred_dim))
y_dist = GMM2D(torch.reshape(log_pis, [k, -1, ph, GMM_c]),
torch.reshape(mus, [k, -1, ph, GMM_c * pred_dim]),
torch.reshape(log_sigmas, [k, -1, ph, GMM_c * pred_dim]),
torch.reshape(corrs, [k, -1, ph, GMM_c]),
self.pred_state_length, self.device,
self.hyperparams['log_sigma_min'], self.hyperparams['log_sigma_max'])
if mode == ModeKeys.PREDICT:
return y_dist, sampled_future
else:
return y_dist
def encoder(self, x, y, mode, num_samples=None):
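        """Fit q(z | x, y) and p(z | x), sample z from q, and (during
        training, when `kl_exact` is set) compute the exact KL(q || p)
        objective term."""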
if mode == ModeKeys.TRAIN:
sample_ct = self.hyperparams['k']
elif mode == ModeKeys.EVAL:
sample_ct = self.hyperparams['k_eval']
elif mode == ModeKeys.PREDICT:
sample_ct = num_samples
if num_samples is None:
raise ValueError("num_samples cannot be None with mode == PREDICT.")
self.latent.q_dist = self.q_z_xy(x, y, mode)
self.latent.p_dist = self.p_z_x(x, mode)
z = self.latent.sample_q(sample_ct, mode)
if mode == ModeKeys.TRAIN and self.hyperparams['kl_exact']:
kl_obj = self.latent.kl_q_p(self.log_writer, '%s' % str(self.node_type), self.curr_iter)
if self.log_writer is not None:
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'kl'), kl_obj, self.curr_iter)
else:
kl_obj = None
return z, kl_obj
def decoder(self, x, y, z, tensor_dict, mode, num_predicted_timesteps, num_samples):
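        """Evaluate log p(y | x, z) under the decoded GMM, clamping each
        timestep's log-likelihood at `log_p_yt_xz_max` and summing over the
        prediction horizon."""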
y_dist = self.p_y_xz(x, z, tensor_dict, mode, num_predicted_timesteps, num_samples)
log_p_yt_xz = torch.clamp(y_dist.log_prob(y), max=self.hyperparams['log_p_yt_xz_max'])
if self.hyperparams['log_histograms'] and self.log_writer is not None:
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_yt_xz'), log_p_yt_xz, self.curr_iter)
log_p_y_xz = torch.sum(log_p_yt_xz, dim=2)
return log_p_y_xz
def mutual_inf_mc(self, x_dist):
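        """Monte Carlo estimate of the mutual information between the input
        and a categorical latent: H(E_x[p(z|x)]) - E_x[H(p(z|x))]."""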
dist = x_dist.__class__
H_y = dist(probs=x_dist.probs.mean(dim=0)).entropy()
return (H_y - x_dist.entropy().mean(dim=0)).sum()
def train_loss(self,
inputs,
inputs_st,
first_history_indices,
labels,
labels_st,
scene,
node_scene_graph_batched,
timestep,
timesteps_in_scene,
prediction_horizon):
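        """Compute the training objective. With alpha close to 1 and IWAE
        disabled this is the negative ELBO,
            -(E_q[log p(y|x,z)] - kl_weight * KL(q(z|x,y) || p(z|x)) + I_p),
        where I_p is the mutual information term added below; otherwise an
        importance-weighted (Renyi) bound is used."""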
mode = ModeKeys.TRAIN
tensor_dict = self.obtain_encoded_tensor_dict(mode,
timestep,
timesteps_in_scene,
inputs,
inputs_st,
labels,
labels_st,
first_history_indices,
scene,
node_scene_graph_batched)
z, kl = self.encoder(tensor_dict["x"], tensor_dict[self.node_type + "_future_encoder"], mode)
log_p_y_xz = self.decoder(tensor_dict["x"], tensor_dict["node_future"], z, tensor_dict, mode,
prediction_horizon,
self.hyperparams['k'])
        # Computed before the branch so the logging below works for both
        # objectives.
        mutual_inf_q = self.mutual_inf_mc(self.latent.q_dist)
        mutual_inf_p = self.mutual_inf_mc(self.latent.p_dist)
        if np.abs(self.hyperparams['alpha'] - 1.0) < 1e-3 and not self.hyperparams['use_iwae']:
            log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0)  # [nbs]
            log_likelihood = torch.mean(log_p_y_xz_mean)
ELBO = log_likelihood - self.kl_weight * kl + 1. * mutual_inf_p
loss = -ELBO
if self.hyperparams['log_histograms'] and self.log_writer is not None:
self.log_writer.add_histogram('%s/%s' % (str(self.node_type), 'log_p_y_xz'),
log_p_y_xz_mean,
self.curr_iter)
else:
log_q_z_xy = self.latent.q_log_prob(z) # [k, nbs]
log_p_z_x = self.latent.p_log_prob(z) # [k, nbs]
a = self.hyperparams['alpha']
log_pp_over_q = log_p_y_xz + log_p_z_x - log_q_z_xy
            log_likelihood = (torch.mean(torch.logsumexp(log_pp_over_q * (1. - a), dim=0))
                              - torch.log(torch.tensor(self.hyperparams['k'], dtype=torch.float,
                                                       device=self.device))) / (1. - a)
loss = -log_likelihood
if self.log_writer is not None:
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_q'),
mutual_inf_q,
self.curr_iter)
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'mutual_information_p'),
mutual_inf_p,
self.curr_iter)
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'log_likelihood'),
log_likelihood,
self.curr_iter)
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'loss'),
loss,
self.curr_iter)
if self.hyperparams['log_histograms']:
self.latent.summarize_for_tensorboard(self.log_writer, str(self.node_type), self.curr_iter)
return loss
def eval_loss(self,
inputs,
inputs_st,
first_history_indices,
labels,
labels_st,
scene,
node_scene_graph_batched,
timestep,
timesteps_in_scene,
prediction_horizon,
compute_naive=True,
compute_exact=True,
compute_sample=True):
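        """Return four NLL estimates for the labels: importance-sampled
        (z ~ q), naively sampled (z ~ p), exact enumeration over all K**N
        latent modes (when that count is small), and a sampled-decoder
        estimate."""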
mode = ModeKeys.EVAL
tensor_dict = self.obtain_encoded_tensor_dict(mode,
timestep,
timesteps_in_scene,
inputs,
inputs_st,
labels,
labels_st,
first_history_indices,
scene,
node_scene_graph_batched)
### Importance sampled NLL estimate
z, _ = self.encoder(tensor_dict["x"], tensor_dict[self.node_type + "_future_encoder"],
mode) # [k_eval, nbs, N*K]
log_p_y_xz = self.decoder(tensor_dict["x"], tensor_dict['node_future'], z, tensor_dict, mode,
prediction_horizon,
self.hyperparams['k_eval']) # [k_eval, nbs]
log_q_z_xy = self.latent.q_log_prob(z) # [k_eval, nbs]
log_p_z_x = self.latent.p_log_prob(z) # [k_eval, nbs]
log_likelihood = torch.mean(torch.logsumexp(log_p_y_xz + log_p_z_x - log_q_z_xy, dim=0)) - \
torch.log(torch.tensor(self.hyperparams['k_eval'], dtype=torch.float, device=self.device))
nll_q_is = -log_likelihood
### Naive sampled NLL estimate
nll_p = torch.tensor(np.nan)
if compute_naive:
z = self.latent.sample_p(self.hyperparams['k_eval'], mode)
log_p_y_xz = self.decoder(tensor_dict["x"], tensor_dict['node_future'], z, tensor_dict, mode,
prediction_horizon,
self.hyperparams['k_eval'])
log_likelihood_p = torch.mean(torch.logsumexp(log_p_y_xz, dim=0)) - \
torch.log(
torch.tensor(self.hyperparams['k_eval'], dtype=torch.float, device=self.device))
nll_p = -log_likelihood_p
### Exact NLL
nll_exact = torch.tensor(np.nan)
if compute_exact:
K, N = self.hyperparams['K'], self.hyperparams['N']
if K ** N < 50:
nbs = tensor_dict["x"].size()[0]
z_raw = torch.from_numpy(
DiscreteLatent.all_one_hot_combinations(N, K).astype(np.float32)
).to(self.device).repeat(1, nbs) # [K**N, nbs*N*K]
z = torch.reshape(z_raw, (K ** N, -1, N * K)) # [K**N, nbs, N*K]
log_p_y_xz = self.decoder(tensor_dict["x"], tensor_dict['node_future'], z, tensor_dict, mode,
prediction_horizon,
K ** N) # [K**N, nbs]
log_p_z_x = self.latent.p_log_prob(z) # [K**N, nbs]
exact_log_likelihood = torch.mean(torch.logsumexp(log_p_y_xz + log_p_z_x, dim=0))
nll_exact = -exact_log_likelihood
nll_sampled = torch.tensor(np.nan)
if compute_sample:
z = self.latent.sample_p(self.hyperparams['k_eval'], mode)
y_dist, _ = self.p_y_xz(tensor_dict["x"], z, tensor_dict, ModeKeys.PREDICT, prediction_horizon,
self.hyperparams['k_eval'])
log_p_yt_xz = torch.clamp(y_dist.log_prob(tensor_dict['node_future']),
max=self.hyperparams['log_p_yt_xz_max'])
log_p_y_xz = torch.sum(log_p_yt_xz, dim=2)
log_p_y_xz_mean = torch.mean(log_p_y_xz, dim=0) # [nbs]
log_likelihood = torch.mean(log_p_y_xz_mean)
nll_sampled = -log_likelihood
if self.log_writer is not None:
self.log_writer.add_scalar('%s/%s' % (str(self.node_type), 'log_likelihood_eval'),
log_likelihood,
self.curr_iter)
return nll_q_is, nll_p, nll_exact, nll_sampled
def predict(self,
inputs,
inputs_st,
labels,
labels_st,
first_history_indices,
scene,
node_scene_graph_batched,
timestep,
timesteps_in_scene,
prediction_horizon,
num_samples_z,
num_samples_gmm,
most_likely_z=False,
most_likely_gmm=False,
all_z=False):
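        """Sample future trajectories for each node. Returns a tensor of
        shape [num_samples_z, num_samples_gmm, bs, prediction_horizon,
        pred_state_length]; the most_likely_* flags select modes instead of
        sampling (see p_y_xz and DiscreteLatent.sample_p)."""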
mode = ModeKeys.PREDICT
tensor_dict = self.obtain_encoded_tensor_dict(mode,
timestep,
timesteps_in_scene,
inputs,
inputs_st,
labels,
labels_st,
first_history_indices,
scene,
node_scene_graph_batched)
self.latent.p_dist = self.p_z_x(tensor_dict["x"], mode)
z, num_samples_z = self.latent.sample_p(num_samples_z,
mode,
num_samples_gmm=num_samples_gmm,
most_likely=most_likely_z,
all_z=all_z)
y_dist, our_sampled_future = self.p_y_xz(tensor_dict["x"], z, tensor_dict, mode,
prediction_horizon,
num_samples_z,
num_samples_gmm,
most_likely_gmm) # y_dist.mean is [k, bs, ph*state_dim]
return our_sampled_future