223 lines
8.5 KiB
Python
223 lines
8.5 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from model.dynamics import Dynamic
|
||
|
from utils import block_diag
|
||
|
from model.components import GMM2D
|
||
|
|
||
|
|
||
|
class Unicycle(Dynamic):
    r"""
    Unicycle dynamics model.

    The state is the stack ``[x_p, y_p, phi, v]`` (position, heading, speed)
    and the control is the stack ``[dphi, a]`` (turn rate, acceleration), both
    stacked along dimension 0. One integration step spans ``self.dt``.

    Because the closed-form update divides by ``dphi``, every method switches
    to a straight-line update (heading held constant) wherever
    ``|dphi| <= _DPHI_EPS``.

    NOTE(review): ``self.device``, ``self.dt``, ``self.node_type``,
    ``self.model_registrar`` and ``self.initial_conditions`` are not set in
    this class; presumably they are provided by ``Dynamic`` or the caller.
    """

    # Turn rates with magnitude at or below this value use the straight-line
    # update instead of the closed-form one (which divides by the turn rate).
    _DPHI_EPS = 1e-2

    def init_constants(self):
        r"""
        Build the constant 4x4 matrix ``F_s`` and its transpose ``F_s_t``.

        ``F_s`` is the identity with ``dt * I_2`` written into rows 0-1,
        columns 2-3. Neither attribute is read by the other methods of this
        class.
        """
        self.F_s = torch.eye(4, device=self.device, dtype=torch.float32)
        self.F_s[0:2, 2:] = torch.eye(2, device=self.device, dtype=torch.float32) * self.dt
        self.F_s_t = self.F_s.transpose(-2, -1)

    def create_graph(self, xz_size):
        r"""
        Create (or fetch) the linear layer that corrects the initial heading.

        The layer maps ``xz_size + 1`` inputs (the encoding concatenated with
        the velocity-derived heading) to a single output.

        :param xz_size: Size of the last dimension of the encoding ``x`` that
            is later passed to the ``integrate_*`` methods.
        """
        model_if_absent = nn.Linear(xz_size + 1, 1)
        # presumably returns the already-registered model when one exists
        # under this name and registers `model_if_absent` otherwise
        self.p0_model = self.model_registrar.get_model(f"{self.node_type}/unicycle_initializer", model_if_absent)

    def _turn_terms(self, phi, dphi):
        r"""
        Compute the quantities shared by the dynamics and both Jacobians.

        :param phi: Heading.
        :param dphi: Raw turn rate.
        :return: Tuple ``(mask, dphi, phi_p_omega_dt, dsin_domega,
            dcos_domega)`` where ``mask`` is True where the raw turn rate is
            too small, ``dphi`` is the turn rate with those entries replaced
            by 1, ``phi_p_omega_dt = phi + dphi * dt`` and the last two are
            the sine / cosine differences divided by ``dphi``.
        """
        mask = torch.abs(dphi) <= self._DPHI_EPS
        # Replace near-zero turn rates by 1 so the divisions below stay
        # finite; the affected entries are discarded by the callers'
        # torch.where in favour of the straight-line result.
        dphi = ~mask * dphi + (mask) * 1

        phi_p_omega_dt = phi + dphi * self.dt
        dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi
        dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi
        return mask, dphi, phi_p_omega_dt, dsin_domega, dcos_domega

    def _initial_state(self, x):
        r"""
        Assemble the initial state ``[x_p, y_p, phi, v]`` stacked on dim 0.

        The heading is ``atan2`` of the initial velocity plus a learned
        correction ``tanh(p0_model([x, phi_0]))``, which is bounded to
        (-1, 1); the speed is the norm of the initial velocity.

        :param x: Encoding concatenated with the heading along the last
            dimension before being fed to ``p0_model``.
        :return: Stacked initial state.
        """
        p_0 = self.initial_conditions['pos'].unsqueeze(1)
        v_0 = self.initial_conditions['vel'].unsqueeze(1)

        phi_0 = torch.atan2(v_0[..., 1], v_0[..., 0])
        phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1)))

        return torch.stack([p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0)

    def dynamic(self, x, u):
        r"""
        Advance the unicycle state by one step of length ``self.dt``.

        :param x: State; ``x[0..3]`` are ``x_p, y_p, phi, v``.
        :param u: Control; ``u[0]`` is the turn rate ``dphi`` and ``u[1]``
            the acceleration ``a``.
        :return: Next state stacked along dimension 0 in the same order as
            ``x``. Entries whose turn rate is at most ``_DPHI_EPS`` in
            magnitude come from the straight-line update (heading unchanged).
        """
        x_p = x[0]
        y_p = x[1]
        phi = x[2]
        v = x[3]
        a = u[1]

        mask, dphi, phi_p_omega_dt, dsin_domega, dcos_domega = self._turn_terms(phi, u[0])

        # Closed-form update for a non-negligible turn rate.
        d1 = torch.stack([(x_p
                           + (a / dphi) * dcos_domega
                           + v * dsin_domega
                           + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt),
                          (y_p
                           - v * dcos_domega
                           + (a / dphi) * dsin_domega
                           - (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt),
                          phi_p_omega_dt,
                          v + a * self.dt], dim=0)

        # Straight-line update used where the turn rate is (almost) zero.
        d2 = torch.stack([x_p + v * torch.cos(phi) * self.dt + (a / 2) * torch.cos(phi) * self.dt ** 2,
                          y_p + v * torch.sin(phi) * self.dt + (a / 2) * torch.sin(phi) * self.dt ** 2,
                          phi * torch.ones_like(a),  # broadcast phi to the shape of the controls
                          v + a * self.dt], dim=0)

        return torch.where(~mask, d1, d2)

    def integrate_samples(self, control_samples, x=None):
        r"""
        Roll sampled controls through the dynamics and return the positions.

        :param control_samples: Sampled controls. The last dimension holds
            ``(dphi, a)`` and the second-to-last dimension is the horizon.
        :param x: Encoding used to correct the initial heading. Required
            despite the ``None`` default.
        :return: Positions ``(x_p, y_p)`` for every step, with the steps
            stacked along dimension 2.
        :raises TypeError: If ``x`` is ``None``.
        """
        if x is None:
            # torch.cat would raise the same exception type with an opaque message.
            raise TypeError("Unicycle.integrate_samples requires the encoding `x`; got None")

        ph = control_samples.shape[-2]

        # After dropping the last dimension the horizon becomes the last
        # dimension of u, hence the `u[..., t]` indexing below.
        u = torch.stack([control_samples[..., 0], control_samples[..., 1]], dim=0)
        x = self._initial_state(x).squeeze(dim=-1)

        mus_list = []
        for t in range(ph):
            x = self.dynamic(x, u[..., t])
            mus_list.append(torch.stack((x[0], x[1]), dim=-1))

        pos_mus = torch.stack(mus_list, dim=2)
        return pos_mus

    def compute_control_jacobian(self, sample_batch_dim, components, x, u):
        r"""
        Jacobian of one dynamics step with respect to the control ``[dphi, a]``.

        :param sample_batch_dim: List of leading dimensions of the result.
        :param components: Size of the dimension preceding the 4x2 matrix.
        :param x: State; only ``x[2]`` (phi) and ``x[3]`` (v) are read.
        :param u: Control; ``u[0]`` is ``dphi`` and ``u[1]`` is ``a``.
        :return: Tensor of shape ``sample_batch_dim + [components, 4, 2]``.
            Entries with a turn rate of at most ``_DPHI_EPS`` in magnitude
            use the Jacobian of the straight-line update, whose ``dphi``
            column is zero.
        """
        F = torch.zeros(sample_batch_dim + [components, 4, 2],
                        device=self.device,
                        dtype=torch.float32)

        phi = x[2]
        v = x[3]
        a = u[1]

        mask, dphi, phi_p_omega_dt, dsin_domega, dcos_domega = self._turn_terms(phi, u[0])

        # Row 0: d x_p' / d(dphi, a)
        F[..., 0, 0] = ((v / dphi) * torch.cos(phi_p_omega_dt) * self.dt
                        - (v / dphi) * dsin_domega
                        - (2 * a / dphi ** 2) * torch.sin(phi_p_omega_dt) * self.dt
                        - (2 * a / dphi ** 2) * dcos_domega
                        + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt ** 2)
        F[..., 0, 1] = (1 / dphi) * dcos_domega + (1 / dphi) * torch.sin(phi_p_omega_dt) * self.dt

        # Row 1: d y_p' / d(dphi, a)
        F[..., 1, 0] = ((v / dphi) * dcos_domega
                        - (2 * a / dphi ** 2) * dsin_domega
                        + (2 * a / dphi ** 2) * torch.cos(phi_p_omega_dt) * self.dt
                        + (v / dphi) * torch.sin(phi_p_omega_dt) * self.dt
                        + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt ** 2)
        F[..., 1, 1] = (1 / dphi) * dsin_domega - (1 / dphi) * torch.cos(phi_p_omega_dt) * self.dt

        # Rows 2 and 3: phi' = phi + dphi * dt and v' = v + a * dt
        F[..., 2, 0] = self.dt
        F[..., 3, 1] = self.dt

        # Jacobian of the straight-line update (heading does not depend on dphi).
        F_sm = torch.zeros(sample_batch_dim + [components, 4, 2],
                           device=self.device,
                           dtype=torch.float32)
        F_sm[..., 0, 1] = (torch.cos(phi) * self.dt ** 2) / 2
        F_sm[..., 1, 1] = (torch.sin(phi) * self.dt ** 2) / 2
        F_sm[..., 3, 1] = self.dt

        return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm)

    def compute_jacobian(self, sample_batch_dim, components, x, u):
        r"""
        Jacobian of one dynamics step with respect to the state
        ``[x_p, y_p, phi, v]``.

        :param sample_batch_dim: List of leading dimensions of the result.
        :param components: Size of the dimension preceding the 4x4 matrix.
        :param x: State; only ``x[2]`` (phi) and ``x[3]`` (v) are read.
        :param u: Control; ``u[0]`` is ``dphi`` and ``u[1]`` is ``a``.
        :return: Tensor of shape ``sample_batch_dim + [components, 4, 4]``.
            Entries with a turn rate of at most ``_DPHI_EPS`` in magnitude
            use the Jacobian of the straight-line update.
        """
        F = torch.zeros(sample_batch_dim + [components, 4, 4],
                        device=self.device,
                        dtype=torch.float32)

        phi = x[2]
        v = x[3]
        a = u[1]

        mask, dphi, phi_p_omega_dt, dsin_domega, dcos_domega = self._turn_terms(phi, u[0])

        # Every state component depends on itself with unit gain.
        F[..., 0, 0] = 1.0
        F[..., 1, 1] = 1.0
        F[..., 2, 2] = 1.0
        F[..., 3, 3] = 1.0

        # d x_p' / d(phi, v)
        F[..., 0, 2] = v * dcos_domega - (a / dphi) * dsin_domega + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt
        F[..., 0, 3] = dsin_domega

        # d y_p' / d(phi, v)
        F[..., 1, 2] = v * dsin_domega + (a / dphi) * dcos_domega + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt
        F[..., 1, 3] = -dcos_domega

        # Jacobian of the straight-line update.
        F_sm = torch.zeros(sample_batch_dim + [components, 4, 4],
                           device=self.device,
                           dtype=torch.float32)
        F_sm[..., 0, 0] = 1.0
        F_sm[..., 1, 1] = 1.0
        F_sm[..., 2, 2] = 1.0
        F_sm[..., 3, 3] = 1.0

        F_sm[..., 0, 2] = -v * torch.sin(phi) * self.dt - (a * torch.sin(phi) * self.dt ** 2) / 2
        F_sm[..., 0, 3] = torch.cos(phi) * self.dt

        F_sm[..., 1, 2] = v * torch.cos(phi) * self.dt + (a * torch.cos(phi) * self.dt ** 2) / 2
        F_sm[..., 1, 3] = torch.sin(phi) * self.dt

        return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm)

    def integrate_distribution(self, control_dist_dphi_a, x):
        r"""
        Propagate a mixture over controls through the linearised dynamics.

        The means are rolled through :meth:`dynamic`; the 4x4 state covariance
        is propagated as ``F S F^T + G Q G^T`` with the state Jacobian ``F``,
        the control Jacobian ``G`` and the control covariance ``Q`` of each
        step. Only the 2x2 position block of every step is kept.

        :param control_dist_dphi_a: Mixture over ``(dphi, a)``. Its ``mus``
            provide the leading two dimensions and the horizon (third from
            last dimension); ``components``, ``log_pis`` and
            ``get_covariance_matrix()`` are also used.
        :param x: Encoding used to correct the initial heading.
        :return: ``GMM2D`` over positions built from the input ``log_pis``,
            the position means and the position covariances, with the steps
            stacked along dimension 2.
        """
        sample_batch_dim = list(control_dist_dphi_a.mus.shape[0:2])
        components = control_dist_dphi_a.components
        ph = control_dist_dphi_a.mus.shape[-3]

        dist_sigma_matrix = control_dist_dphi_a.get_covariance_matrix()
        # The state covariance starts at zero; float32 matches the Jacobians
        # it is multiplied with.
        pos_dist_sigma_matrix_t = torch.zeros(sample_batch_dim + [components, 4, 4],
                                              device=self.device,
                                              dtype=torch.float32)

        u = torch.stack([control_dist_dphi_a.mus[..., 0], control_dist_dphi_a.mus[..., 1]], dim=0)
        x = self._initial_state(x)

        pos_dist_sigma_matrix_list = []
        mus_list = []
        for t in range(ph):
            u_t = u[:, :, :, t]

            # Linearise around the state *before* the step is applied.
            F_t = self.compute_jacobian(sample_batch_dim, components, x, u_t)
            G_t = self.compute_control_jacobian(sample_batch_dim, components, x, u_t)
            dist_sigma_matrix_t = dist_sigma_matrix[:, :, t]
            pos_dist_sigma_matrix_t = (F_t.matmul(pos_dist_sigma_matrix_t.matmul(F_t.transpose(-2, -1)))
                                       + G_t.matmul(dist_sigma_matrix_t.matmul(G_t.transpose(-2, -1))))
            pos_dist_sigma_matrix_list.append(pos_dist_sigma_matrix_t[..., :2, :2])

            x = self.dynamic(x, u_t)
            mus_list.append(torch.stack((x[0], x[1]), dim=-1))

        pos_dist_sigma_matrix = torch.stack(pos_dist_sigma_matrix_list, dim=2)
        pos_mus = torch.stack(mus_list, dim=2)
        return GMM2D.from_log_pis_mus_cov_mats(control_dist_dphi_a.log_pis, pos_mus, pos_dist_sigma_matrix)
|