388 lines
20 KiB
Python
388 lines
20 KiB
Python
|
import warnings
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
import numpy as np
|
||
|
from collections import defaultdict, Counter
|
||
|
from model.components import *
|
||
|
from model.model_utils import *
|
||
|
import model.dynamics as dynamic_module
|
||
|
from model.mgcvae import MultimodalGenerativeCVAE
|
||
|
from environment.scene_graph import DirectedEdge
|
||
|
from environment.node_type import NodeType
|
||
|
|
||
|
|
||
|
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
    """Online (incremental) variant of ``MultimodalGenerativeCVAE`` for a single node.

    Instead of encoding a whole history window per call, the node-history and edge LSTM
    encoders are fed one timestep per call, and their recurrent states are kept in
    ``self.curr_hidden_states`` between calls. Edge encoders are added and removed as the
    scene graph changes (see :meth:`update_graph`).
    """

    def __init__(self,
                 env,
                 node,
                 model_registrar,
                 hyperparams,
                 device):
        """Build the online model for ``node``.

        Does not call ``super().__init__``; all attributes are initialized here.

        :param env: Environment; must contain exactly one scene (its ``robot`` and ``dt`` are used).
        :param node: The node this model encodes and predicts for.
        :param model_registrar: Registrar used to fetch/create trainable submodules.
        :param hyperparams: Hyperparameter dict (reads ``state``, ``pred_state``, ``incl_robot_node``,
            ``dynamic`` among others).
        :param device: Torch device for submodules and tensors.
        :raises ValueError: If ``env`` does not contain exactly one scene.
        """
        self.hyperparams = hyperparams
        self.node = node
        self.node_type = self.node.type

        if len(env.scenes) != 1:
            raise ValueError("Passed in Environment has number of scenes != 1")

        self.robot = env.scenes[0].robot
        self.model_registrar = model_registrar
        self.device = device

        self.node_modules = dict()
        self.env = env
        self.scene_graph = None

        self.state = self.hyperparams['state']
        self.pred_state = self.hyperparams['pred_state'][self.node.type]
        self.state_length = self._state_dim_count(self.state[self.node.type])
        if self.hyperparams['incl_robot_node']:
            self.robot_state_length = self._state_dim_count(self.state[self.robot.type])
        self.pred_state_length = self._state_dim_count(self.pred_state)

        # Recurrent (LSTM) states carried across calls, keyed by submodule name.
        self.curr_hidden_states = dict()
        # Current number of neighbors per edge-type string ("A->B").
        self.edge_types = Counter()

        self.create_graphical_model()

        dynamic_class = getattr(dynamic_module, self.hyperparams['dynamic'][self.node_type]['name'])
        dyn_limits = hyperparams['dynamic'][self.node_type]['limits']
        # NOTE(review): self.x_size is not set in this class; presumably it is set by the
        # inherited create_node_models() called from create_graphical_model() -- confirm.
        self.dynamic = dynamic_class(self.env.scenes[0].dt, dyn_limits, device, self.model_registrar, self.x_size)

    @staticmethod
    def _state_dim_count(state_spec):
        """Return the total number of dimensions in a ``{entity: dims}`` state spec."""
        return int(np.sum([len(entity_dims) for entity_dims in state_spec.values()]))

    def create_graphical_model(self):
        """
        Creates or queries all trainable components.

        Only node-level modules are created here; edge encoders are added lazily by
        :meth:`add_edge_model` as neighbors appear.

        :return: None
        """
        self.clear_submodules()

        ############################
        #   Everything but Edges   #
        ############################
        self.create_node_models()

        for module in self.node_modules.values():
            module.to(self.device)

    def update_graph(self, new_scene_graph, new_neighbors, removed_neighbors):
        """Adopt a new scene graph and sync edge encoders and per-type neighbor counts.

        :param new_scene_graph: Scene graph for the current timestep.
        :param new_neighbors: Mapping ``node -> {edge_type_str: collection of newly added neighbors}``.
        :param removed_neighbors: Mapping ``node -> {edge_type_str: collection of removed neighbors}``.
        """
        self.scene_graph = new_scene_graph

        if self.node in new_neighbors:
            for edge_type, new_neighbor_nodes in new_neighbors[self.node].items():
                self.add_edge_model(edge_type)
                self.edge_types += Counter({edge_type: len(new_neighbor_nodes)})

        if self.node in removed_neighbors:
            for edge_type, removed_neighbor_nodes in removed_neighbors[self.node].items():
                self.remove_edge_model(edge_type)
                # Counter subtraction drops edge types whose count falls to zero or below,
                # so encode_edge() stops being called for them.
                self.edge_types -= Counter({edge_type: len(removed_neighbor_nodes)})

    def get_edge_to(self, other_node):
        """Return the ``DirectedEdge`` from this model's node to ``other_node``."""
        return DirectedEdge(self.node, other_node)

    def add_edge_model(self, edge_type):
        """Create (or fetch from the registrar) the edge encoder for ``edge_type`` if it is missing.

        :param edge_type: Edge-type string of the form ``"A->B"``.
        """
        if not self.hyperparams['edge_encoding']:
            return
        if edge_type + '/edge_encoder' in self.node_modules:
            return

        neighbor_state_length = self._state_dim_count(
            self.state[self._get_other_node_type_from_edge(edge_type)])

        combine_method = self.hyperparams['edge_state_combine_method']
        # NOTE: encode_edge() only implements 'sum', 'max' and 'mean'; the 'pointnet' and
        # 'attention' submodules below are registered but not consumed by this online model.
        if combine_method == 'pointnet':
            self.add_submodule(edge_type + '/pointnet_encoder',
                               model_if_absent=nn.Sequential(
                                   nn.Linear(self.state_length, 2 * self.state_length),
                                   nn.ReLU(),
                                   nn.Linear(2 * self.state_length, 2 * self.state_length),
                                   nn.ReLU()))

            edge_encoder_input_size = 2 * self.state_length + self.state_length

        elif combine_method == 'attention':
            self.add_submodule(self.node.type + '/edge_attention_combine',
                               model_if_absent=TemporallyBatchedAdditiveAttention(
                                   encoder_hidden_state_dim=self.state_length,
                                   decoder_hidden_state_dim=self.state_length))
            edge_encoder_input_size = self.state_length + neighbor_state_length

        else:
            # Combined neighbor state concatenated with our own state (see encode_edge).
            edge_encoder_input_size = self.state_length + neighbor_state_length

        self.add_submodule(edge_type + '/edge_encoder',
                           model_if_absent=nn.LSTM(input_size=edge_encoder_input_size,
                                                   hidden_size=self.hyperparams['enc_rnn_dim_edge'],
                                                   batch_first=True))

    def _node_type_from_str(self, node_type_str):
        """Rebuild the ``NodeType`` named ``node_type_str`` (value = 1-based index in ``env.node_type_list``)."""
        return NodeType(node_type_str, self.env.node_type_list.index(node_type_str) + 1)

    def _get_other_node_type_from_edge(self, edge_type_str):
        """Return the ``NodeType`` on the right-hand side of an ``"A->B"`` edge-type string."""
        return self._node_type_from_str(edge_type_str.split('->')[1])

    def _get_edge_type_from_str(self, edge_type_str):
        """Parse an ``"A->B"`` edge-type string into a ``(NodeType A, NodeType B)`` tuple."""
        n1_type_str, n2_type_str = edge_type_str.split('->')
        return self._node_type_from_str(n1_type_str), self._node_type_from_str(n2_type_str)

    def remove_edge_model(self, edge_type):
        """Drop the edge encoder for ``edge_type`` once no neighbors of that type remain.

        The encoder's carried LSTM state is discarded as well. Otherwise, a later
        re-added encoder for the same type would resume from a stale state recorded
        before its neighbors disappeared.

        :param edge_type: Edge-type string of the form ``"A->B"``.
        """
        if not self.hyperparams['edge_encoding']:
            return
        if len(self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))) == 0:
            encoder_key = edge_type + '/edge_encoder'
            self.node_modules.pop(encoder_key, None)
            self.curr_hidden_states.pop(encoder_key, None)

    def obtain_encoded_tensors(self, mode, inputs, inputs_st, inputs_np, robot_present_and_future) -> torch.Tensor:
        """Encode the current timestep and return the encoder representation ``x``.

        Side effects:

        - sets the initial condition of ``self.dynamic``;
        - advances the carried LSTM states;
        - stores the per-step encodings in ``self.TD`` for reuse by :meth:`decoder_forward`.

        :param mode: ``ModeKeys`` value forwarded to the inherited encoders.
        :param inputs: Mapping node -> raw state tensor (assumed ``[bs, state_dim]`` -- TODO confirm).
        :param inputs_st: Mapping node -> standardized state tensor.
        :param inputs_np: Mapping node -> raw state as a numpy array.
        :param robot_present_and_future: Robot trajectory, or ``None``.
        :return: Concatenated encoder representation (see :meth:`create_encoder_rep`).
        """
        x_r_t, y_r = None, None
        batch_size = 1

        our_inputs = inputs[self.node]
        our_inputs_st = inputs_st[self.node]

        initial_dynamics = dict()
        initial_dynamics['pos'] = our_inputs_st[:, 0:2]  # TODO: Generalize
        initial_dynamics['vel'] = our_inputs_st[:, 2:4]  # TODO: Generalize
        self.dynamic.set_initial_condition(initial_dynamics)

        #########################################
        # Provide basic information to encoders #
        #########################################
        if self.hyperparams['incl_robot_node'] and self.robot is not None:
            x_r_t, y_r = self.get_relative_robot_traj(our_inputs, robot_present_and_future, self.robot.type)

        ##################
        # Encode History #
        ##################
        node_history_encoded = self.encode_node_history(our_inputs_st)

        ##############################
        # Encode Node Edges per Type #
        ##############################
        total_edge_influence = None
        if self.hyperparams['edge_encoding']:
            node_edges_encoded = list()
            for edge_type in self.edge_types:
                # All nodes connected to the current node (for this edge type) at the current timestep.
                connected_nodes_batched = [
                    self.scene_graph.get_neighbors(self.node, self._get_other_node_type_from_edge(edge_type))
                ]

                edge_masks_batched = list()
                if self.hyperparams['dynamic_edges'] == 'yes':
                    # Edge masks for the current node at the current timestep.
                    edge_masks_for_node = self.scene_graph.get_edge_scaling(self.node)
                    edge_masks_batched.append(torch.tensor(edge_masks_for_node, dtype=torch.float, device=self.device))

                # Encode edges for given edge type
                encoded_edges_type = self.encode_edge(inputs,
                                                      inputs_st,
                                                      inputs_np,
                                                      edge_type,
                                                      connected_nodes_batched,
                                                      edge_masks_batched)
                node_edges_encoded.append(encoded_edges_type)  # List of [bs/nbs, enc_rnn_dim]

            #####################
            # Encode Node Edges #
            #####################
            total_edge_influence = self.encode_total_edge_influence(mode,
                                                                    node_edges_encoded,
                                                                    node_history_encoded,
                                                                    batch_size)

        self.TD = {'node_history_encoded': node_history_encoded,
                   'total_edge_influence': total_edge_influence}

        ######################################
        # Concatenate Encoder Outputs into x #
        ######################################
        return self.create_encoder_rep(mode, self.TD, x_r_t, y_r)

    def create_encoder_rep(self, mode,
                           TD,
                           robot_present_st,
                           robot_future_st):
        """Concatenate the stored encodings (and robot-future encoding, if used) into ``x``.

        :param mode: ``ModeKeys`` value forwarded to ``encode_robot_future``.
        :param TD: Dict with ``'node_history_encoded'`` and ``'total_edge_influence'`` tensors.
        :param robot_present_st: Robot present state relative to the node, or ``None``.
        :param robot_future_st: Robot future(s) relative to the node, or ``None``. Its first
            dimension is used as the number of candidate robot futures.
        :return: Tensor concatenated along dim 1: [edge influence,] history[, robot future].
        """
        node_history_encoded = TD['node_history_encoded']
        if self.hyperparams['edge_encoding']:
            total_edge_influence = TD['total_edge_influence']

        use_robot = self.hyperparams['incl_robot_node'] and self.robot is not None
        if use_robot and robot_future_st is not None and robot_present_st is not None:
            robot_future_encoder = self.encode_robot_future(mode, robot_present_st, robot_future_st)

            # Tiling for multiple samples
            # This tiling is done because:
            # a) we must consider the prediction case where there are many candidate robot future actions,
            # b) the edge and history encoders are all the same regardless of which candidate future robot action
            #    we're evaluating.
            num_robot_futures = robot_future_st.size()[0]
            node_history_encoded = TD['node_history_encoded'].repeat(num_robot_futures, 1)
            if self.hyperparams['edge_encoding']:
                total_edge_influence = TD['total_edge_influence'].repeat(num_robot_futures, 1)

        elif use_robot:
            # Four times because we're trying to mimic a bi-directional RNN's output (which is c and h from both ends).
            robot_future_encoder = torch.zeros([1, 4 * self.hyperparams['enc_rnn_dim_future']], device=self.device)

        x_concat_list = list()

        # Every node has an edge-influence encoder (which could just be zero).
        if self.hyperparams['edge_encoding']:
            x_concat_list.append(total_edge_influence)  # [bs/nbs, 4*enc_rnn_dim]

        # Every node has a history encoder.
        x_concat_list.append(node_history_encoded)  # [bs/nbs, enc_rnn_dim_history]

        if use_robot:
            x_concat_list.append(robot_future_encoder)  # [bs/nbs, 4*enc_rnn_dim_future]

        return torch.cat(x_concat_list, dim=1)

    def encode_node_history(self, inputs_st):
        """Advance the node-history LSTM by one timestep.

        :param inputs_st: Standardized node state, ``[bs, state_dim]``.
        :return: LSTM output for this step, ``[bs, hidden]``.
        """
        encoder_key = self.node.type + '/node_history_encoder'
        new_state = torch.unsqueeze(inputs_st, dim=1)  # [bs, 1, state_dim]
        # hx=None (first call) makes the LSTM start from zero states.
        outputs, self.curr_hidden_states[encoder_key] = self.node_modules[encoder_key](
            new_state, self.curr_hidden_states.get(encoder_key))

        return outputs[:, 0, :]

    def encode_edge(self, inputs, inputs_st, inputs_np, edge_type, connected_nodes, edge_masks):
        """Advance the edge LSTM for ``edge_type`` by one timestep.

        Neighbor states are standardized relative to this node's raw state. They are
        combined with the configured ``edge_state_combine_method`` ('sum', 'max' or
        'mean'), concatenated with this node's standardized state, and fed to the
        edge encoder.

        :param inputs: Mapping node -> raw state tensor.
        :param inputs_st: Mapping node -> standardized state tensor.
        :param inputs_np: Mapping node -> raw state as numpy array.
        :param edge_type: Edge-type string ``"A->B"``.
        :param connected_nodes: One-element list holding the neighbors of this type.
        :param edge_masks: Edge-scaling tensors (only used when ``dynamic_edges == 'yes'``).
        :return: Edge encoding ``[bs, enc_rnn_dim]``, scaled by the edge mask when dynamic edges are on.
        :raises NotImplementedError: For combine methods other than 'sum', 'max' and 'mean'.
        """
        edge_type_tuple = self._get_edge_type_from_str(edge_type)
        edge_states_list = list()  # list of [#of neighbors, max_ht, state_dim]
        neighbor_states = list()

        # Our own raw state is used as the standardization mean, making neighbors relative to us.
        rel_state = inputs[self.node].cpu().numpy()
        for node in connected_nodes[0]:
            neighbor_state_np = inputs_np[node]

            # Make State relative to node
            _, std = self.env.get_standardize_params(self.state[node.type], node_type=node.type)
            # The first two dims (assumed position, as in obtain_encoded_tensors) are scaled by the
            # attention radius of this edge type.
            std[0:2] = self.env.attention_radius[edge_type_tuple]
            neighbor_state_np_st = self.env.standardize(neighbor_state_np,
                                                        self.state[node.type],
                                                        node_type=node.type,
                                                        mean=rel_state,
                                                        std=std)

            neighbor_state = torch.tensor(neighbor_state_np_st).float().to(self.device)
            neighbor_states.append(neighbor_state)

        if len(neighbor_states) == 0:
            # No neighbors for this edge type: use a single all-zero neighbor state.
            # Index the parsed NodeType: edge_type is the "A->B" string, so edge_type[1] is one character.
            neighbor_state_length = self._state_dim_count(self.state[edge_type_tuple[1]])
            edge_states_list.append(torch.zeros((1, 1, neighbor_state_length), device=self.device))
        else:
            edge_states_list.append(torch.stack(neighbor_states, dim=0))

        combine_method = self.hyperparams['edge_state_combine_method']
        if combine_method == 'sum':
            # Used in Structural-RNN to combine edges as well.
            reduce_fn = torch.sum
        elif combine_method == 'max':
            # Used in NLP, e.g. max over word embeddings in a sentence.
            # torch.max(..., dim=...) returns a (values, indices) namedtuple; keep only the values.
            def reduce_fn(tensor, dim, keepdim=False):
                return torch.max(tensor, dim=dim, keepdim=keepdim).values
        elif combine_method == 'mean':
            # Used in NLP, e.g. mean over word embeddings in a sentence.
            reduce_fn = torch.mean
        else:
            raise NotImplementedError(
                f"edge_state_combine_method '{combine_method}' is not supported by the online model "
                f"(expected 'sum', 'max' or 'mean')")

        combined_neighbors = torch.stack([reduce_fn(neighbors_state, dim=0)
                                          for neighbors_state in edge_states_list], dim=0)
        if self.hyperparams['dynamic_edges'] == 'yes':
            # Should now be (bs, time, 1)
            combined_edge_masks = torch.stack(
                [torch.clamp(reduce_fn(edge_mask, dim=0, keepdim=True), max=1.) for edge_mask in edge_masks],
                dim=0)

        joint_history = torch.cat([combined_neighbors, torch.unsqueeze(inputs_st[self.node], dim=0)], dim=-1)

        encoder_key = edge_type + '/edge_encoder'
        # hx=None (first call for this edge type) makes the LSTM start from zero states.
        outputs, self.curr_hidden_states[encoder_key] = self.node_modules[encoder_key](
            joint_history, self.curr_hidden_states.get(encoder_key))

        if self.hyperparams['dynamic_edges'] == 'yes':
            return outputs[:, 0, :] * combined_edge_masks
        return outputs[:, 0, :]  # [bs, enc_rnn_dim]

    def encoder_forward(self, inputs, inputs_st, inputs_np, robot_present_and_future=None):
        """Run the encoder for the current timestep and set ``self.x`` and ``self.latent.p_dist``.

        :param inputs: Mapping node -> raw state tensor.
        :param inputs_st: Mapping node -> standardized state tensor.
        :param inputs_np: Mapping node -> raw state as numpy array.
        :param robot_present_and_future: Optional robot trajectory.
        """
        # Always predicting with the online model.
        mode = ModeKeys.PREDICT

        self.x = self.obtain_encoded_tensors(mode,
                                             inputs,
                                             inputs_st,
                                             inputs_np,
                                             robot_present_and_future)

        self.latent.p_dist = self.p_z_x(mode, self.x)

    def decoder_forward(self, prediction_horizon,
                        num_samples,
                        robot_present_and_future=None,
                        z_mode=False,
                        gmm_mode=False,
                        full_dist=False,
                        all_z_sep=False):
        """Sample future trajectories from the current encoding.

        Must be called after :meth:`encoder_forward`, whose ``self.x``/``self.TD`` it reuses.
        ``robot_present_and_future`` is optional. If it is given, and the robot is included,
        the robot part of the representation is recomputed from it and ``self.x`` and
        ``self.latent.p_dist`` are rebuilt.

        :param prediction_horizon: Number of future steps to predict.
        :param num_samples: Number of samples to draw.
        :param robot_present_and_future: Optional robot trajectory to condition on.
        :param z_mode: Use the most likely latent ``z``.
        :param gmm_mode: Forwarded to ``p_y_xz``.
        :param full_dist: Forwarded to ``latent.sample_p``.
        :param all_z_sep: Forwarded to ``latent.sample_p``.
        :return: Sampled future for this node, as returned by ``p_y_xz``.
        """
        # Always predicting with the online model.
        mode = ModeKeys.PREDICT

        x_nr_t, y_r = None, None
        if (self.hyperparams['incl_robot_node']
                and self.robot is not None
                and robot_present_and_future is not None):
            our_present_state = torch.tensor(self.node.get(np.array([self.node.last_timestep]),
                                                           self.state[self.node.type],
                                                           padding=0.0),
                                             dtype=torch.float,
                                             device=self.device)
            x_nr_t, y_r = self.get_relative_robot_traj(our_present_state,
                                                       robot_present_and_future,
                                                       self.robot.type)
            self.x = self.create_encoder_rep(mode, self.TD, x_nr_t, y_r)
            self.latent.p_dist = self.p_z_x(mode, self.x)

        z, num_samples, num_components = self.latent.sample_p(num_samples,
                                                              mode,
                                                              most_likely_z=z_mode,
                                                              full_dist=full_dist,
                                                              all_z_sep=all_z_sep)

        _, our_sampled_future = self.p_y_xz(mode, self.x, x_nr_t, y_r, z,
                                            prediction_horizon,
                                            num_samples,
                                            num_components,
                                            gmm_mode)

        return our_sampled_future
|