dynamic thresholding (maybe?)

This commit is contained in:
rromb 2022-06-05 19:22:07 +02:00
parent 64cc718ead
commit a0f390afb6
2 changed files with 40 additions and 5 deletions

View file

@ -4,6 +4,7 @@ import torch
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from functools import partial from functools import partial
from einops import rearrange
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
@ -73,8 +74,8 @@ class DDIMSampler(object):
x_T=None, x_T=None,
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1., unconditional_guidance_scale=1.,
unconditional_conditioning=None, unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... dynamic_threshold=None,
**kwargs **kwargs
): ):
if conditioning is not None: if conditioning is not None:
@ -106,6 +107,7 @@ class DDIMSampler(object):
log_every_t=log_every_t, log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
) )
return samples, intermediates return samples, intermediates
@ -115,7 +117,7 @@ class DDIMSampler(object):
callback=None, timesteps=None, quantize_denoised=False, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100, mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,): unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None):
device = self.model.betas.device device = self.model.betas.device
b = shape[0] b = shape[0]
if x_T is None: if x_T is None:
@ -150,7 +152,8 @@ class DDIMSampler(object):
noise_dropout=noise_dropout, score_corrector=score_corrector, noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning) unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs img, pred_x0 = outs
if callback: callback(i) if callback: callback(i)
if img_callback: img_callback(pred_x0, i) if img_callback: img_callback(pred_x0, i)
@ -164,7 +167,8 @@ class DDIMSampler(object):
@torch.no_grad() @torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None): unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.: if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
@ -194,6 +198,33 @@ class DDIMSampler(object):
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
# renorm
pred_max = pred_x0.max()
pred_min = pred_x0.min()
pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1
pred_x0 = 2*pred_x0 - 1. # -1 ... 1
s = torch.quantile(
rearrange(pred_x0, 'b ... -> b (...)').abs(),
dynamic_threshold,
dim=-1
)
s.clamp_(min=1.0)
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
# clip by threshold
#pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
# temporary hack: numpy on cpu
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
# re.renorm
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature

View file

@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module):
def encode(self, *args, **kwargs): def encode(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
    """No-op conditioning encoder: returns its input unchanged.

    Useful when the conditioning signal is already in the representation the
    model consumes, so no learned encoding step is needed.
    """

    def encode(self, x):
        # Identity mapping — x is passed through untouched.
        return x
class ClassEmbedder(nn.Module): class ClassEmbedder(nn.Module):