108 lines
4.5 KiB
Python
108 lines
4.5 KiB
Python
|
import torch
|
||
|
import torch.distributions as td
|
||
|
import numpy as np
|
||
|
from model.model_utils import ModeKeys, tile
|
||
|
|
||
|
|
||
|
class DiscreteLatent(object):
    """Discrete latent variable made of N independent K-way one-hot categoricals.

    The flattened latent vector has size ``z_dim = N * K``. The prior ``p_dist``
    and posterior ``q_dist`` are ``torch.distributions.OneHotCategorical``
    objects (built by ``dist_from_h``) that are assigned from outside this class.
    """

    def __init__(self, hyperparams, device):
        """Store latent-space hyperparameters.

        Args:
            hyperparams: mapping that must provide 'N' (number of categorical
                latents), 'K' (classes per latent) and 'kl_min' (per-latent lower
                bound applied to the KL term in ``kl_q_p`` when > 0).
            device: torch device used for tensors created inside this class.
        """
        self.hyperparams = hyperparams
        self.z_dim = hyperparams['N'] * hyperparams['K']
        self.N = hyperparams['N']
        self.K = hyperparams['K']
        self.kl_min = hyperparams['kl_min']
        self.device = device
        self.temp = None  # filled in by MultimodalGenerativeCVAE.set_annealing_params
        self.z_logit_clip = None  # filled in by MultimodalGenerativeCVAE.set_annealing_params
        self.p_dist = None  # filled in by MultimodalGenerativeCVAE.encoder
        self.q_dist = None  # filled in by MultimodalGenerativeCVAE.encoder

    def dist_from_h(self, h, mode):
        """Build a OneHotCategorical over N K-way latents from the tensor ``h``.

        ``h`` is reshaped to (-1, N, K), so it must contain a multiple of N*K
        elements. Logits are centered to zero mean over the last axis, which
        leaves the softmax unchanged but keeps the values bounded. In TRAIN mode,
        when ``z_logit_clip`` is set, the centered logits are clamped to
        [-z_logit_clip, z_logit_clip].

        Returns:
            td.OneHotCategorical with batch shape (-1, N) and event shape (K,).
        """
        logits_separated = torch.reshape(h, (-1, self.N, self.K))
        logits_separated_mean_zero = logits_separated - torch.mean(logits_separated, dim=-1, keepdim=True)
        if self.z_logit_clip is not None and mode == ModeKeys.TRAIN:
            c = self.z_logit_clip
            logits = torch.clamp(logits_separated_mean_zero, min=-c, max=c)
        else:
            logits = logits_separated_mean_zero

        return td.OneHotCategorical(logits=logits)

    def sample_q(self, k, mode):
        """Draw ``k`` samples from the posterior ``q_dist``.

        TRAIN draws reparameterized samples from a RelaxedOneHotCategorical
        (temperature ``self.temp``) with q's logits, so gradients flow. EVAL
        draws hard one-hot samples from ``q_dist``.

        Returns:
            Tensor reshaped to (k, -1, z_dim).

        Raises:
            ValueError: if ``mode`` is neither ModeKeys.TRAIN nor ModeKeys.EVAL.
        """
        if mode == ModeKeys.TRAIN:
            z_dist = td.RelaxedOneHotCategorical(self.temp, logits=self.q_dist.logits)
            z_NK = z_dist.rsample((k,))
        elif mode == ModeKeys.EVAL:
            z_NK = self.q_dist.sample((k,))
        else:
            # Previously fell through and failed with an UnboundLocalError on z_NK.
            raise ValueError(
                "sample_q only supports ModeKeys.TRAIN and ModeKeys.EVAL, got {!r}".format(mode))
        return torch.reshape(z_NK, (k, -1, self.z_dim))

    def sample_p(self, num_samples_z, mode, num_samples_gmm=1, most_likely=False, all_z=False):
        """Draw latent samples from the prior ``p_dist``.

        Exactly one strategy is used, checked in this order:
          * ``all_z``: enumerate all K**N one-hot configurations for every batch
            element; ``num_samples_z`` is overridden to K**N.
          * ``most_likely``: take the argmax class of each latent and repeat that
            configuration ``num_samples_z`` times.
          * otherwise: sample ``num_samples_z`` times from ``p_dist``.
        The result is then passed through the project helper ``tile`` along
        dim 0 with ``num_samples_gmm`` (presumably repeating each z sample
        num_samples_gmm times — see model_utils.tile).

        Returns:
            Tensor reshaped to (num_samples_z * num_samples_gmm, -1, N*K). In
            ModeKeys.PREDICT mode, a tuple of that tensor and the (possibly
            overridden) ``num_samples_z``.
        """
        if all_z:
            bs = self.p_dist.probs.size()[0]
            # [K**N, N*K] combinations, repeated along the last axis so the final
            # reshape gives every batch element the same set of combinations.
            z_NK = torch.from_numpy(self.all_one_hot_combinations(self.N, self.K)).float().to(self.device).repeat(1, bs)
            num_samples_z = self.K ** self.N

        elif most_likely:
            # Sampling the most likely z from p(z|x).
            eye_mat = torch.eye(self.p_dist.event_shape[-1], device=self.device)
            argmax_idxs = torch.argmax(self.p_dist.probs, dim=2)
            # Index the identity matrix to turn class indices into one-hot rows.
            z_NK = torch.unsqueeze(eye_mat[argmax_idxs], dim=0).expand(num_samples_z, -1, -1, -1)

        else:
            z_NK = self.p_dist.sample((num_samples_z,))

        z_NK = tile(z_NK, 0, num_samples_gmm, device=self.device)
        k = num_samples_z * num_samples_gmm

        if mode == ModeKeys.PREDICT:
            return torch.reshape(z_NK, (k, -1, self.N * self.K)), num_samples_z
        else:
            return torch.reshape(z_NK, (k, -1, self.N * self.K))

    def kl_q_p(self, log_writer=None, prefix=None, curr_iter=None):
        """Return KL(q || p) summed over the N latents.

        The per-latent KL is averaged over dim 0 (the batch; a 1-D result is
        first given a leading axis). When ``kl_min`` > 0 each latent's averaged
        KL is clamped from below at ``kl_min`` (a free-bits style bound) before
        summing. If ``log_writer`` is given, the unclamped sum is logged as
        ``prefix + '/true_kl'`` (so ``prefix`` must then be a string).
        """
        kl_separated = td.kl_divergence(self.q_dist, self.p_dist)
        if len(kl_separated.size()) < 2:
            kl_separated = torch.unsqueeze(kl_separated, dim=0)

        kl_minibatch = torch.mean(kl_separated, dim=0, keepdim=True)

        if log_writer is not None:
            log_writer.add_scalar(prefix + '/true_kl', torch.sum(kl_minibatch), curr_iter)

        if self.kl_min > 0:
            kl_lower_bounded = torch.clamp(kl_minibatch, min=self.kl_min)
            kl = torch.sum(kl_lower_bounded)
        else:
            kl = torch.sum(kl_minibatch)

        return kl

    def q_log_prob(self, z):
        """Log-probability of flattened samples ``z`` under ``q_dist``.

        ``z`` is reshaped to (k, -1, N, K) with k = z.size(0); per-latent
        log-probs are summed over the N latents, giving shape (k, -1).
        """
        k = z.size()[0]
        z_NK = torch.reshape(z, [k, -1, self.N, self.K])
        return torch.sum(self.q_dist.log_prob(z_NK), dim=2)

    def p_log_prob(self, z):
        """Log-probability of flattened samples ``z`` under ``p_dist``.

        Same reshaping and summation as ``q_log_prob``.
        """
        k = z.size()[0]
        z_NK = torch.reshape(z, [k, -1, self.N, self.K])
        return torch.sum(self.p_dist.log_prob(z_NK), dim=2)

    def get_p_dist_probs(self):
        """Return the class probabilities of the prior ``p_dist``."""
        return self.p_dist.probs

    @staticmethod
    def all_one_hot_combinations(N, K):
        """Return every configuration of N one-hot K-way vectors.

        Returns:
            numpy array of shape [K**N, N*K]; each row concatenates N one-hot
            vectors of length K.
        """
        # np.indices enumerates all N-tuples of class indices; taking rows of
        # eye(K) turns each index into a one-hot vector.
        return np.eye(K).take(np.reshape(np.indices([K] * N), [N, -1]).T, axis=0).reshape(-1, N * K)  # [K**N, N*K]

    def summarize_for_tensorboard(self, log_writer, prefix, curr_iter):
        """Write histograms of p/q probabilities and logits to ``log_writer``.

        Per-entry logit histograms of q are also written when z_dim <= 9, which
        keeps the number of tags small and the "{i}{j}" tag suffix unambiguous.
        """
        log_writer.add_histogram(prefix + "/latent/p_z_x", self.p_dist.probs, curr_iter)
        log_writer.add_histogram(prefix + "/latent/q_z_xy", self.q_dist.probs, curr_iter)
        log_writer.add_histogram(prefix + "/latent/p_z_x_logits", self.p_dist.logits, curr_iter)
        log_writer.add_histogram(prefix + "/latent/q_z_xy_logits", self.q_dist.logits, curr_iter)
        if self.z_dim <= 9:
            for i in range(self.N):
                for j in range(self.K):
                    log_writer.add_histogram(prefix + "/latent/q_z_xy_logit{0}{1}".format(i, j),
                                             self.q_dist.logits[:, i, j],
                                             curr_iter)
|