307 lines
15 KiB
Python
307 lines
15 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
from collections import Counter
|
||
|
from model.trajectron import Trajectron
|
||
|
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||
|
from model.model_utils import ModeKeys
|
||
|
from data import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||
|
|
||
|
|
||
|
class OnlineTrajectron(Trajectron):
    """Trajectron that consumes observations one timestep at a time.

    Instead of being handed a whole scene, the model keeps a fixed-size
    position history per node and a scene graph, and updates both every
    time :meth:`incremental_forward` receives a new dict of positions.
    """

    def __init__(self, model_registrar, hyperparams, device):
        """Create an empty online model.

        Args:
            model_registrar: Forwarded to the ``Trajectron`` constructor and
                to every per-node model created later.
            hyperparams: Hyperparameter mapping. This class reads the keys
                ``edge_removal_filter``, ``edge_addition_filter``,
                ``maximum_history_length``, ``state``, ``dynamic_edges`` and
                ``incl_robot_node``.
            device: Device on which input tensors are created.
        """
        super().__init__(model_registrar=model_registrar,
                         hyperparams=hyperparams,
                         log_writer=False,
                         device=device)

        # node -> RingBuffer holding that node's most recent positions.
        self.node_data = {}
        # Scene graph of the latest processed timestep (set by set_environment).
        self.scene_graph = None
        # History length kept per node: long enough for both edge filters and
        # for the maximum history length, plus one slot.
        self.RING_CAPACITY = max(len(self.hyperparams['edge_removal_filter']),
                                 len(self.hyperparams['edge_addition_filter']),
                                 self.hyperparams['maximum_history_length']) + 1
        # node -> offset that was used as the mean when standardizing the
        # node's latest input; reused to unstandardize predictions.
        self.rel_states = {}
        # node -> number of consecutive updates in which the node was missing.
        self.removed_nodes = Counter()

    def __repr__(self):
        return f"OnlineTrajectron(# nodes: {len(self.nodes)}, device: {self.device}, hyperparameters: {str(self.hyperparams)}) "

    def _add_node_model(self, node):
        """Register ``node`` and create its per-node online model.

        Raises:
            ValueError: If ``node`` is already part of the graph.
        """
        if node in self.nodes:
            raise ValueError('%s was already added to this graph!' % str(node))

        self.nodes.add(node)
        self.node_models_dict[node] = OnlineMultimodalGenerativeCVAE(self.env,
                                                                     node,
                                                                     self.model_registrar,
                                                                     self.hyperparams,
                                                                     self.device)

    def update_removed_nodes(self):
        """Forget nodes that have been missing for too many updates.

        A node is dropped from both ``node_data`` and ``removed_nodes`` once
        its missing-count reaches the length of the ``edge_removal_filter``.
        """
        max_missing_steps = len(self.hyperparams['edge_removal_filter'])

        # Iterate over a snapshot because entries are deleted inside the loop.
        for node in list(self.removed_nodes):
            if self.removed_nodes[node] >= max_missing_steps:
                del self.node_data[node]
                del self.removed_nodes[node]

    def _remove_node_model(self, node):
        """Unregister ``node`` and discard its per-node model.

        Raises:
            ValueError: If ``node`` is not part of the graph.
        """
        if node not in self.nodes:
            raise ValueError('%s is not in this graph!' % str(node))

        self.nodes.remove(node)
        del self.node_models_dict[node]

    def set_environment(self, env, init_timestep=0):
        """Reset all online state and replay ``env`` up to ``init_timestep``.

        Args:
            env: Environment to attach. Its ``attention_radius`` and the first
                entry of its ``scenes`` are used.
            init_timestep: Last timestep (inclusive) of the first scene that
                is replayed to warm up the position histories.
        """
        self.env = env
        self.scene_graph = SceneGraph(edge_radius=self.env.attention_radius)

        self.nodes.clear()
        self.node_data.clear()
        self.node_models_dict.clear()
        # The removal bookkeeping must be reset together with node_data:
        # a leftover entry in removed_nodes would make the next
        # incremental_forward look up a node that node_data no longer has
        # and fail with a KeyError. rel_states is cleared as well so that no
        # offsets from the previous environment linger.
        self.removed_nodes.clear()
        self.rel_states.clear()

        # Fast-forwarding ourselves to the initial timestep, without running
        # any of the underlying models.
        for timestep in range(init_timestep + 1):
            self.incremental_forward(
                self.env.scenes[0].get_clipped_pos_dict(timestep, self.hyperparams['state']),
                run_models=False)

    def _prepare_robot_future(self, robot_present_and_future, horizon):
        """Add a batch axis to the robot trajectory and check its length.

        The trajectory is only touched when the robot node is enabled in the
        hyperparameters, the first scene has a robot and a trajectory was
        actually supplied; otherwise the argument is returned unchanged.

        Args:
            robot_present_and_future: Array with a ``shape`` attribute or
                ``None``. A 2-D array gets a leading axis of size one.
            horizon: Number of future timesteps; axis 1 of the (batched)
                array must have ``horizon + 1`` entries.

        Returns:
            The possibly re-shaped ``robot_present_and_future``.
        """
        robot_in_use = (self.hyperparams['incl_robot_node']
                        and self.env.scenes[0].robot is not None
                        and robot_present_and_future is not None)
        if not robot_in_use:
            return robot_present_and_future

        # We want tensors of shape (1, ph + 1, state_dim) where the first 1
        # is the batch size.
        if len(robot_present_and_future.shape) == 2:
            robot_present_and_future = robot_present_and_future[np.newaxis, :]

        # Deliberately still an assert so callers observe the same
        # AssertionError as before.
        assert robot_present_and_future.shape[1] == horizon + 1
        return robot_present_and_future

    def incremental_forward(self, new_inputs_dict,
                            prediction_horizon=0,
                            num_samples=0,
                            robot_present_and_future=None,
                            z_mode=False,
                            gmm_mode=False,
                            full_dist=False,
                            all_z_sep=False,
                            run_models=True):
        """Feed one timestep of observations into the model.

        The new datapoints are appended to each node's history, the scene
        graph and the set of per-node models are updated, the encoders are
        advanced and, if requested, new trajectories are sampled.

        Args:
            new_inputs_dict: Mapping from node to its newly observed position.
            prediction_horizon: Number of timesteps to predict. ``0`` only
                updates the encoders.
            num_samples: Number of samples to draw. ``0`` only updates the
                encoders.
            robot_present_and_future: Optional robot trajectory; see
                ``_prepare_robot_future`` for the accepted shapes.
            z_mode, gmm_mode, full_dist, all_z_sep: Forwarded unchanged to
                ``sample_model``.
            run_models: When ``False`` only the histories, the scene graph and
                the model structure are updated; no model is executed.

        Returns:
            ``None`` when ``run_models`` is ``False`` or when either
            ``prediction_horizon`` or ``num_samples`` is ``0``; otherwise the
            dict returned by ``sample_model``.
        """
        # No grad since this class is only ever used for prediction.
        with torch.no_grad():
            # --- 1. Append the new observations to the per-node histories.
            for node, new_input in new_inputs_dict.items():
                if node not in self.node_data:
                    self.node_data[node] = RingBuffer(
                        capacity=self.RING_CAPACITY,
                        dtype=(float, len(self.state[node.type]['position'])))
                self.node_data[node].append(new_input)

                # A node that shows up again is no longer considered removed.
                self.removed_nodes.pop(node, None)

            # --- 2. Book-keeping for nodes that were not observed this step.
            # Nodes in self.node_data that aren't in new_inputs_dict were
            # just removed.
            newly_removed_nodes = (set(self.node_data.keys())
                                   - set(self.removed_nodes.keys())
                                   - set(new_inputs_dict.keys()))

            # Counting the newly removed nodes together with all existing
            # removed nodes increases everyone's time-since-removal by one.
            self.removed_nodes.update(newly_removed_nodes | set(self.removed_nodes.keys()))

            # Nodes missing for at least len(edge_removal_filter) updates
            # have their data cleared.
            self.update_removed_nodes()

            # Any remaining removed nodes simply have NaNs appended so that
            # when it's passed through the LSTMs, the hidden state keeps
            # propagating but the input plays no role (the NaNs get converted
            # to zeros later on). The row width matches the width the ring
            # buffer was created with.
            for node in self.removed_nodes:
                self.node_data[node].append(
                    np.full((1, len(self.state[node.type]['position'])), np.nan))

            # --- 3. Recompute velocities/accelerations from the histories.
            for node, history in self.node_data.items():
                dt = self.env.scenes[0].dt
                x = history[:, 0]
                y = history[:, 1]
                vx = derivative_of(x, dt)
                vy = derivative_of(y, dt)
                ax = derivative_of(vx, dt)
                ay = derivative_of(vy, dt)

                new_node_data = np.stack([x, y, vx, vy, ax, ay], axis=-1)
                node.overwrite_data(
                    new_node_data,
                    forward_in_time_on_next_overwrite=(history.shape[0] == self.RING_CAPACITY))

            # --- 4. Build the scene graph of the current timestep.
            temp_scene_dict = {k: v[:, 0:2] for k, v in self.node_data.items()}
            if not temp_scene_dict:
                new_scene_graph = SceneGraph(edge_radius=self.env.attention_radius)
            else:
                new_scene_graph = TemporalSceneGraph.create_from_temp_scene_dict(
                    temp_scene_dict,
                    self.env.attention_radius,
                    duration=self.RING_CAPACITY,
                    edge_addition_filter=self.hyperparams['edge_addition_filter'],
                    edge_removal_filter=self.hyperparams['edge_removal_filter'],
                    online=True).to_scene_graph(t=self.RING_CAPACITY - 1)

            # --- 5. Propagate graph changes to the per-node models.
            if self.hyperparams['dynamic_edges'] == 'yes':
                new_nodes, removed_nodes, new_neighbors, removed_neighbors = \
                    new_scene_graph - self.scene_graph

                # Aside from updating the scene graph, this for loop updates
                # the graph model structure of all affected nodes.
                not_removed_nodes = [node for node in self.nodes if node not in removed_nodes]
                self.scene_graph = new_scene_graph
                for node in not_removed_nodes:
                    self.node_models_dict[node].update_graph(new_scene_graph,
                                                             new_neighbors,
                                                             removed_neighbors)

                # These next 2 for loops add or remove entire node models.
                for node in new_nodes:
                    if node.is_robot:
                        continue

                    self._add_node_model(node)
                    self.node_models_dict[node].update_graph(new_scene_graph,
                                                             new_neighbors,
                                                             removed_neighbors)

                for node in removed_nodes:
                    if node.is_robot:
                        continue

                    self._remove_node_model(node)

            # --- 6. Update the node models with the newly observed data.
            if not run_models:
                return None

            inputs = {}
            inputs_st = {}
            inputs_np = {}

            iter_list = list(self.node_models_dict.keys())
            if self.env.scenes[0].robot is not None:
                iter_list.append(self.env.scenes[0].robot)

            for node in iter_list:
                input_np = node.get(np.array([node.last_timestep, node.last_timestep]),
                                    self.state[node.type])

                # NOTE(review): `std` is computed and its first two entries
                # are overwritten with the attention radius, but it is never
                # passed to `standardize` below. Confirm whether `std=std`
                # was intended; the statements are kept untouched in case
                # `get_standardize_params` returns a shared array.
                _, std = self.env.get_standardize_params(self.state[node.type.name], node.type)
                std[0:2] = self.env.attention_radius[(node.type, node.type)]

                # Positions are expressed relative to the current position.
                rel_state = np.zeros_like(input_np)
                rel_state[:, 0:2] = input_np[:, 0:2]
                input_st = self.env.standardize(input_np,
                                                self.state[node.type.name],
                                                node.type,
                                                mean=rel_state)
                self.rel_states[node] = rel_state

                # Converting NaNs to zeros.
                input_np[np.isnan(input_np)] = 0
                input_st[np.isnan(input_st)] = 0

                # Convert to torch tensors
                inputs[node] = torch.tensor(input_np, dtype=torch.float, device=self.device)
                inputs_st[node] = torch.tensor(input_st, dtype=torch.float, device=self.device)
                inputs_np[node] = input_np

            robot_present_and_future = self._prepare_robot_future(robot_present_and_future,
                                                                  prediction_horizon)

            for node_model in self.node_models_dict.values():
                node_model.encoder_forward(inputs,
                                           inputs_st,
                                           inputs_np,
                                           robot_present_and_future)

            # If prediction_horizon or num_samples == 0 then do not run the
            # decoder at all, just update the encoder LSTMs.
            if prediction_horizon == 0 or num_samples == 0:
                return None

            return self.sample_model(prediction_horizon,
                                     num_samples,
                                     robot_present_and_future=robot_present_and_future,
                                     z_mode=z_mode,
                                     gmm_mode=gmm_mode,
                                     full_dist=full_dist,
                                     all_z_sep=all_z_sep)

    def _run_decoder(self, node,
                     num_predicted_timesteps,
                     num_samples,
                     robot_present_and_future=None,
                     z_mode=False,
                     gmm_mode=False,
                     full_dist=False,
                     all_z_sep=False):
        """Decode predictions for a single node and unstandardize them.

        Args:
            node: Node whose model is run; must be a key of
                ``node_models_dict`` and of ``rel_states``.
            num_predicted_timesteps: Forwarded to the model's
                ``decoder_forward``.
            num_samples: Forwarded to the model's ``decoder_forward``.
            robot_present_and_future, z_mode, gmm_mode, full_dist, all_z_sep:
                Forwarded unchanged to the model's ``decoder_forward``.

        Returns:
            The unstandardized predictions as a NumPy array with its first
            two axes swapped.
        """
        model = self.node_models_dict[node]
        predictions = model.decoder_forward(num_predicted_timesteps,
                                            num_samples,
                                            robot_present_and_future=robot_present_and_future,
                                            z_mode=z_mode,
                                            gmm_mode=gmm_mode,
                                            full_dist=full_dist,
                                            all_z_sep=all_z_sep)

        # Undo the standardization using the same positional offset that was
        # applied to this node's latest input.
        predictions_uns = self.env.unstandardize(predictions.cpu().detach().numpy(),
                                                 self.pred_state[node.type.name],
                                                 node.type,
                                                 mean=self.rel_states[node][..., 0:2])

        # Return will be of shape (batch_size, num_samples, num_predicted_timesteps, 2)
        return np.transpose(predictions_uns, (1, 0, 2, 3))

    def sample_model(self, num_predicted_timesteps,
                     num_samples,
                     robot_present_and_future=None,
                     z_mode=False,
                     gmm_mode=False,
                     full_dist=False,
                     all_z_sep=False):
        """Sample trajectories for every currently present, non-robot node.

        Starts from the current encoder state; the encoders are not advanced.

        Args:
            num_predicted_timesteps: Length of each predicted trajectory.
            num_samples: Number of trajectories to sample per node.
            robot_present_and_future: Optional robot trajectory; see
                ``_prepare_robot_future`` for the accepted shapes.
            z_mode, gmm_mode, full_dist, all_z_sep: Forwarded unchanged to
                each node model's decoder.

        Returns:
            ``None`` if ``num_predicted_timesteps`` or ``num_samples`` is
            ``0``; otherwise a dict mapping each node to the array returned
            by ``_run_decoder``.
        """
        if num_predicted_timesteps == 0 or num_samples == 0:
            return None

        robot_present_and_future = self._prepare_robot_future(robot_present_and_future,
                                                              num_predicted_timesteps)

        # No grad since this class is only ever used for prediction.
        with torch.no_grad():
            predictions_dict = {}
            # Nodes that are currently marked as removed get no prediction.
            for node in set(self.nodes) - set(self.removed_nodes.keys()):
                if node.is_robot:
                    continue

                predictions_dict[node] = self._run_decoder(
                    node,
                    num_predicted_timesteps,
                    num_samples,
                    robot_present_and_future=robot_present_and_future,
                    z_mode=z_mode,
                    gmm_mode=gmm_mode,
                    full_dist=full_dist,
                    all_z_sep=all_z_sep)

        return predictions_dict

    def forward(self, init_env,
                init_timestep,
                pos_dicts,  # After the initial environment
                num_predicted_timesteps,
                num_samples,
                robot_present_and_future=None,
                z_mode=False,
                gmm_mode=False,
                full_dist=False,
                all_z_sep=False):
        """Replay historical data and then predict forward.

        This is the standard forward prediction function, for when you have
        some historical data and just want to predict forward some number of
        timesteps.

        Args:
            init_env: Environment passed to ``set_environment``.
            init_timestep: Timestep passed to ``set_environment``.
            pos_dicts: Position dicts for the timesteps after the initial
                environment; entry ``i`` is fed as the ``i``-th update.
            num_predicted_timesteps: Length of each predicted trajectory.
            num_samples: Number of trajectories to sample per node.
            robot_present_and_future, z_mode, gmm_mode, full_dist, all_z_sep:
                Forwarded unchanged to ``sample_model``.

        Returns:
            Whatever ``sample_model`` returns.
        """
        # Setting us back to the initial scene graph we had.
        self.set_environment(init_env, init_timestep)

        # Looping through and applying updates to the model. Indexed access
        # is kept on purpose so that any container supporting len() and
        # integer indexing keeps working.
        for step in range(len(pos_dicts)):
            self.incremental_forward(pos_dicts[step])

        return self.sample_model(num_predicted_timesteps,
                                 num_samples,
                                 robot_present_and_future=robot_present_and_future,
                                 z_mode=z_mode,
                                 gmm_mode=gmm_mode,
                                 full_dist=full_dist,
                                 all_z_sep=all_z_sep)