From a0f390afb65f6d1a21f950c827d29cb76ae80aaf Mon Sep 17 00:00:00 2001 From: rromb Date: Sun, 5 Jun 2022 19:22:07 +0200 Subject: [PATCH] dynamic thresholding (maybe?) --- ldm/models/diffusion/ddim.py | 41 +++++++++++++++++++++++++++++---- ldm/modules/encoders/modules.py | 4 ++++ 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index edf1eaf..7d6cb48 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -4,6 +4,7 @@ import torch import numpy as np from tqdm import tqdm from functools import partial +from einops import rearrange from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like @@ -73,8 +74,8 @@ class DDIMSampler(object): x_T=None, log_every_t=100, unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, **kwargs ): if conditioning is not None: @@ -106,6 +107,7 @@ class DDIMSampler(object): log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, ) return samples, intermediates @@ -115,7 +117,7 @@ class DDIMSampler(object): callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -150,7 +152,8 @@ class DDIMSampler(object): noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -164,7 +167,8 @@ class DDIMSampler(object): @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): b, *_, device = *x.shape, x.device if unconditional_conditioning is None or unconditional_guidance_scale == 1.: @@ -194,6 +198,33 @@ class DDIMSampler(object): pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + # renorm + pred_max = pred_x0.max() + pred_min = pred_x0.min() + pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1 + pred_x0 = 2*pred_x0 - 1. # -1 ... 1 + + s = torch.quantile( + rearrange(pred_x0, 'b ... -> b (...)').abs(), + dynamic_threshold, + dim=-1 + ) + s.clamp_(min=1.0) + s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) + + # clip by threshold + #pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max + + # temporary hack: numpy on cpu + pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() + pred_x0 = torch.tensor(pred_x0).to(self.model.device) + + # re.renorm + pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 + pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range + # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index a62b8e7..68260c3 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module): def encode(self, *args, **kwargs): raise NotImplementedError +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x class ClassEmbedder(nn.Module):