Trajectron-plus-plus/code/model/components/gmm2d.py

61 lines
2.9 KiB
Python

import torch
import torch.distributions as td
import numpy as np
from model.model_utils import to_one_hot
class GMM2D(object):
    """Mixture of bivariate Gaussians with per-component correlation.

    Every component is described by a mean, two log standard deviations and
    one correlation coefficient; the mixture weights are supplied as
    (possibly unnormalised) log-probabilities.

    Args:
        log_pis: Log mixture weights, shape ``[..., GMM_c]``. They are
            normalised here with ``logsumexp`` over the last axis.
        mus: Flattened component means, shape
            ``[..., GMM_c * pred_state_length]``.
        log_sigmas: Flattened log standard deviations, same shape as ``mus``.
            Clamped to ``[clip_lo, clip_hi]`` before exponentiation.
        corrs: Correlation coefficient of each component, shape
            ``[..., GMM_c]``; expected to lie in ``[-1, 1]``. Values whose
            magnitude is too close to 1 are clamped (see
            ``_MIN_ONE_MINUS_RHO2``).
        pred_state_length: Size of the trailing per-component axis created
            when ``mus`` / ``log_sigmas`` are un-flattened. The sampling and
            density formulas below are written for the 2-D case.
        device: Device forwarded to the tensor factory calls and to
            ``to_one_hot``.
        clip_lo: Lower clamp bound for ``log_sigmas``.
        clip_hi: Upper clamp bound for ``log_sigmas``.
    """

    # Smallest value allowed for ``1 - rho^2``. Without a bound, a correlation
    # of exactly +-1 (e.g. a saturated float32 ``tanh``) makes ``log_prob``
    # evaluate ``log(0)`` and ``z / 0`` and makes the gradient of
    # ``sqrt(1 - rho^2)`` infinite, all of which end up as NaN/inf.
    _MIN_ONE_MINUS_RHO2 = 1e-5

    def __init__(self, log_pis, mus, log_sigmas, corrs, pred_state_length, device,
                 clip_lo=-10, clip_hi=10):
        self.device = device
        self.pred_state_length = pred_state_length

        # input shapes
        # pis:    [..., GMM_c]
        # mus:    [..., GMM_c*2]
        # sigmas: [..., GMM_c*2]
        # corrs:  [..., GMM_c]
        GMM_c = log_pis.shape[-1]

        # Normalise the weights so that exp(log_pis) sums to one per batch entry.
        log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
        mus = self.reshape_to_components(mus, GMM_c)                      # [..., GMM_c, 2]
        log_sigmas = self.reshape_to_components(
            torch.clamp(log_sigmas, min=clip_lo, max=clip_hi), GMM_c)
        sigmas = torch.exp(log_sigmas)                                    # [..., GMM_c, 2]

        # Keep |rho| strictly below 1 so that every quantity derived from it
        # (Cholesky factor, log-determinant, quadratic form) stays finite and
        # sampling and log_prob describe the same distribution.
        max_abs_rho = (1.0 - self._MIN_ONE_MINUS_RHO2) ** 0.5
        corrs = torch.clamp(corrs, min=-max_abs_rho, max=max_abs_rho)     # [..., GMM_c]
        one_minus_rho2 = 1 - corrs**2                                     # [..., GMM_c]

        # Covariance and its lower-triangular Cholesky factor (p = correlation):
        #   Sigma = [[s1^2,    p*s1*s2],      L = [[s1,   0             ],
        #            [p*s1*s2, s2^2   ]]           [p*s2, sqrt(1-p^2)*s2]]
        # L1 and L2 hold the first and second column of L respectively.
        self.L1 = sigmas*torch.stack([torch.ones_like(corrs, device=self.device), corrs], dim=-1)
        self.L2 = sigmas*torch.stack([torch.zeros_like(corrs, device=self.device),
                                      torch.sqrt(one_minus_rho2)], dim=-1)

        self.batch_shape = log_pis.shape[:-1]
        self.GMM_c = GMM_c
        self.log_pis = log_pis                  # [..., GMM_c]
        self.mus = mus                          # [..., GMM_c, 2]
        self.log_sigmas = log_sigmas            # [..., GMM_c, 2]
        self.sigmas = sigmas                    # [..., GMM_c, 2]
        self.corrs = corrs                      # [..., GMM_c]
        self.one_minus_rho2 = one_minus_rho2    # [..., GMM_c]
        self.cat = td.Categorical(logits=log_pis)

    def sample(self):
        """Draw one sample from the mixture for every batch entry.

        A candidate is drawn from every component and the one picked by the
        categorical draw is kept.

        Returns:
            Tensor of shape ``[..., 2]``.
        """
        # mu + L @ eps, with the 2x2 matmul written out column by column.
        eps1 = torch.unsqueeze(torch.randn_like(self.corrs, device=self.device), dim=-1)
        eps2 = torch.unsqueeze(torch.randn_like(self.corrs, device=self.device), dim=-1)
        MVN_samples = self.mus + self.L1*eps1 + self.L2*eps2             # [..., GMM_c, 2]

        cat_samples = self.cat.sample()                                  # [...]
        # to_one_hot is a project helper; presumably it returns a [..., GMM_c]
        # one-hot mask of the sampled component indices.
        selector = torch.unsqueeze(to_one_hot(cat_samples, self.GMM_c, self.device), dim=-1)
        return torch.sum(MVN_samples*selector, dim=-2)

    def log_prob(self, x):
        """Return the mixture log-density at ``x``.

        Args:
            x: Points to evaluate, shape ``[..., 2]``.

        Returns:
            Tensor of shape ``[...]`` with the log-density of each point.
        """
        x = torch.unsqueeze(x, dim=-2)          # [..., 1, 2]
        dx = x - self.mus                       # [..., GMM_c, 2]
        # Quadratic form of the bivariate normal, before division by (1 - rho^2).
        z = (torch.sum((dx/self.sigmas)**2, dim=-1) -
             2*self.corrs*torch.prod(dx, dim=-1)/torch.prod(self.sigmas, dim=-1))  # [..., GMM_c]
        component_log_p = -(torch.log(self.one_minus_rho2) + 2*torch.sum(self.log_sigmas, dim=-1) +
                            z/self.one_minus_rho2 +
                            2*np.log(2*np.pi))/2
        # Marginalise over components in log space.
        return torch.logsumexp(self.log_pis + component_log_p, dim=-1)

    def reshape_to_components(self, tensor, GMM_c):
        """Split the flattened last axis of ``tensor`` into ``[GMM_c, pred_state_length]``."""
        return torch.reshape(tensor, list(tensor.shape[:-1]) + [GMM_c, self.pred_state_length])