64 lines
No EOL
2.7 KiB
Python
64 lines
No EOL
2.7 KiB
Python
import torch
|
|
from model.dynamics import Dynamic
|
|
from utils import block_diag
|
|
from model.components import GMM2D
|
|
|
|
|
|
class SingleIntegrator(Dynamic):
    """Single-integrator dynamics: position is the time integral of velocity.

    Velocities (deterministic samples or a GMM distribution) are integrated
    forward from ``self.initial_conditions['pos']`` with time step ``self.dt``.
    """

    def init_constants(self):
        """Build the one-step transition matrix ``F`` and its transpose.

        ``self.F`` is the 4x4 matrix::

            [[I_2, dt * I_2],
             [0,   I_2     ]]

        so applied to a stacked ``[position (2), velocity (2)]`` state it
        advances position by ``dt * velocity`` and leaves velocity unchanged.
        """
        self.F = torch.eye(4, device=self.device, dtype=torch.float32)
        # Upper-right 2x2 block couples velocity into position over one step.
        self.F[0:2, 2:] = torch.eye(2, device=self.device, dtype=torch.float32) * self.dt
        self.F_t = self.F.transpose(-2, -1)

    def integrate_samples(self, v, x=None):
        """
        Integrates deterministic samples of velocity.

        Positions are the cumulative sum of ``v`` along dim 2, scaled by
        ``dt`` and offset by the initial position.

        :param v: Velocity samples. NOTE(review): presumably shaped
            (num_samples, batch, ph, 2), so that the initial position of shape
            (batch, 2) broadcasts after ``unsqueeze(1)``. Confirm with callers.
        :param x: Not used for SI.
        :return: Position samples (broadcast shape of ``v`` and the initial
            position).
        """
        p_0 = self.initial_conditions['pos'].unsqueeze(1)
        return torch.cumsum(v, dim=2) * self.dt + p_0

    def integrate_distribution(self, v_dist, x=None):
        r"""
        Integrates the GMM velocity distribution to a distribution over position.

        The Kalman prediction equations are used. At each step the velocity
        covariance is placed on the diagonal of the joint state covariance,
        next to the current position covariance, with no position-velocity
        cross terms:

        .. math:: \mu_{t+1} = \textbf{F} \mu_{t}

        .. math:: \mathbf{\Sigma}_{t+1} = \textbf{F} \mathbf{\Sigma}_{t} \textbf{F}^{T}

        .. math::
            \textbf{F} = \left[
            \begin{array}{cc}
            \mathbf{I}_2 & \Delta t \, \mathbf{I}_2 \\
            \mathbf{0} & \mathbf{I}_2
            \end{array}
            \right]

        .. math::
            \mathbf{\Sigma}_{t} = \left[
            \begin{array}{cccc}
            \sigma_x^2 & \rho_p \sigma_x \sigma_y & 0 & 0 \\
            \rho_p \sigma_x \sigma_y & \sigma_y^2 & 0 & 0 \\
            0 & 0 & \sigma_{v_x}^2 & \rho_v \sigma_{v_x} \sigma_{v_y} \\
            0 & 0 & \rho_v \sigma_{v_x} \sigma_{v_y} & \sigma_{v_y}^2
            \end{array}
            \right]_{t}

        Only the top-left 2x2 (position) block of
        :math:`\mathbf{\Sigma}_{t+1}` is kept. The position covariance is
        zero before the first step.

        :param v_dist: Joint GMM Distribution over velocity in x and y direction.
        :param x: Not used for SI.
        :return: Joint GMM Distribution over position in x and y direction,
            reusing the mixture log-weights (``log_pis``) of ``v_dist``.
        """
        p_0 = self.initial_conditions['pos'].unsqueeze(1)
        # NOTE(review): the indexing below implies v_dist.mus is shaped
        # (num_samples, batch, ph, components, 2). TODO confirm.
        ph = v_dist.mus.shape[-3]
        sample_batch_dim = list(v_dist.mus.shape[0:2])

        # Mean: cumulative velocity integration from the initial position.
        # The extra None axis lets p_0 broadcast across mixture components.
        pos_mus = p_0[:, None] + torch.cumsum(v_dist.mus, dim=2) * self.dt

        vel_dist_sigma_matrix = v_dist.get_covariance_matrix()
        dtype = vel_dist_sigma_matrix.dtype

        # Position covariance starts at zero. It uses the velocity covariance
        # dtype so block_diag / matmul never mix precisions.
        pos_dist_sigma_matrix_t = torch.zeros(
            sample_batch_dim + [v_dist.components, 2, 2],
            device=self.device,
            dtype=dtype,
        )

        # Loop-invariant slices. (F @ S @ F^T)[:2, :2] == F[:2, :] @ S @ F^T[:, :2],
        # so the two unused output columns are never computed.
        # .to() is a no-op when dtype already matches (float32).
        f_pos_rows = self.F[..., :2, :].to(dtype)
        f_t_pos_cols = self.F_t[..., :2].to(dtype)

        pos_dist_sigma_matrix_list = []
        for t in range(ph):
            vel_sigma_matrix_t = vel_dist_sigma_matrix[:, :, t]
            # Joint [position, velocity] covariance with no cross-covariance.
            full_sigma_matrix_t = block_diag([pos_dist_sigma_matrix_t, vel_sigma_matrix_t])
            pos_dist_sigma_matrix_t = f_pos_rows.matmul(full_sigma_matrix_t.matmul(f_t_pos_cols))
            pos_dist_sigma_matrix_list.append(pos_dist_sigma_matrix_t)

        pos_dist_sigma_matrix = torch.stack(pos_dist_sigma_matrix_list, dim=2)
        return GMM2D.from_log_pis_mus_cov_mats(v_dist.log_pis, pos_mus, pos_dist_sigma_matrix)