Trajectron-plus-plus/trajectron/model/dynamics/single_integrator.py
Ruben van de Ven dfa1d43f2e Change imports to support usage as module
Note that this does require rerunning the `process_data.py` scripts in the example folders so that the dill-files are updated.
2023-10-09 20:27:29 +02:00

64 lines
No EOL
2.7 KiB
Python

import torch
from trajectron.model.dynamics import Dynamic
from trajectron.utils import block_diag
from trajectron.model.components import GMM2D
class SingleIntegrator(Dynamic):
    """Single-integrator dynamics: position is the time integral of velocity.

    Reads ``self.dt``, ``self.device`` and ``self.initial_conditions['pos']``,
    which are expected to be provided by the ``Dynamic`` base class / caller.
    """

    def init_constants(self):
        """Build the constant state-transition matrix ``F`` and its transpose.

        ``F`` is the 4x4 identity with ``dt`` on the diagonal of its
        upper-right 2x2 block, i.e. ``[[I, dt * I], [0, I]]``. Applied to a
        state ordered ``(x, y, v_x, v_y)`` it advances position by
        ``dt * velocity`` and leaves velocity unchanged.
        """
        self.F = torch.eye(4, device=self.device, dtype=torch.float32)
        self.F[0:2, 2:] = torch.eye(2, device=self.device, dtype=torch.float32) * self.dt
        self.F_t = self.F.transpose(-2, -1)

    def integrate_samples(self, v: torch.Tensor, x=None) -> torch.Tensor:
        """
        Integrates deterministic samples of velocity.

        :param v: Velocity samples; time is expected on dim 2 (the axis the
                  cumulative sum runs over).
        :param x: Not used for SI.
        :return: Position samples, same shape as ``v``.
        """
        # Extra axis at dim 1 lets the initial position broadcast against v.
        # NOTE(review): assumes 'pos' is (batch, 2) and v is
        # (samples, batch, time, 2) -- confirm against the caller.
        p_0 = self.initial_conditions['pos'].unsqueeze(1)
        return torch.cumsum(v, dim=2) * self.dt + p_0

    def integrate_distribution(self, v_dist, x=None):
        r"""
        Integrates the GMM velocity distribution to a distribution over position.

        The Kalman prediction equations are used:

        .. math:: \mu_{t+1} = \textbf{F} \mu_{t}

        .. math:: \mathbf{\Sigma}_{t+1} = \textbf{F} \mathbf{\Sigma}_{t} \textbf{F}^{T}

        with the transition matrix (see ``init_constants``)

        .. math::
            \textbf{F} = \left[
                            \begin{array}{cccc}
                                1 & 0 & \Delta t & 0 \\
                                0 & 1 & 0 & \Delta t \\
                                0 & 0 & 1 & 0 \\
                                0 & 0 & 0 & 1 \\
                            \end{array}
                        \right]

        and the block-diagonal joint covariance of position and velocity

        .. math::
            \mathbf{\Sigma}_{t} = \left[
                            \begin{array}{cccc}
                                \sigma_x^2 & \rho_p \sigma_x \sigma_y & 0 & 0 \\
                                \rho_p \sigma_x \sigma_y & \sigma_y^2 & 0 & 0 \\
                                0 & 0 & \sigma_{v_x}^2 & \rho_v \sigma_{v_x} \sigma_{v_y} \\
                                0 & 0 & \rho_v \sigma_{v_x} \sigma_{v_y} & \sigma_{v_y}^2 \\
                            \end{array}
                        \right]_{t}

        Writing :math:`\mathbf{P}_t` / :math:`\mathbf{V}_t` for the position /
        velocity blocks, the position block of
        :math:`\textbf{F} \mathbf{\Sigma}_{t} \textbf{F}^{T}` is
        :math:`\mathbf{P}_t + \Delta t^2 \mathbf{V}_t`. Starting from
        :math:`\mathbf{P} = 0`, the position covariance at every step is
        therefore the running sum of the velocity covariances scaled by
        :math:`\Delta t^2`, which is computed here in a single cumulative sum
        instead of a per-timestep loop.

        :param v_dist: Joint GMM Distribution over velocity in x and y direction.
                       Must expose ``mus``, ``log_pis`` and
                       ``get_covariance_matrix()``; time is expected on dim 2.
        :param x: Not used for SI.
        :return: Joint GMM Distribution over position in x and y direction.
        """
        # NOTE(review): assumes 'pos' is (batch, 2); the two added singleton
        # axes then broadcast over the time and component axes of the means.
        p_0 = self.initial_conditions['pos'].unsqueeze(1)

        # Means: initial position plus integrated mean velocity.
        pos_mus = p_0[:, None] + torch.cumsum(v_dist.mus, dim=2) * self.dt

        # Covariances: closed form of the recursion P_{t+1} = P_t + dt^2 * V_t
        # with P_0 = 0 (see docstring). The result inherits device and dtype
        # from the velocity covariance.
        vel_dist_sigma_matrix = v_dist.get_covariance_matrix()
        pos_dist_sigma_matrix = torch.cumsum(vel_dist_sigma_matrix, dim=2) * self.dt ** 2

        return GMM2D.from_log_pis_mus_cov_mats(v_dist.log_pis, pos_mus, pos_dist_sigma_matrix)