dynamic thresholding (maybe?)
parent 64cc718ead
commit a0f390afb6
2 changed files with 40 additions and 5 deletions
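In short: the DDIM sampler gains an optional dynamic_threshold argument, threaded from sample() through ddim_sampling() into p_sample_ddim(), where a non-None value triggers a quantile-based clipping of the predicted x0; the encoders module additionally gains a trivial IdentityEncoder.
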
@@ -4,6 +4,7 @@ import torch
 import numpy as np
 from tqdm import tqdm
 from functools import partial
+from einops import rearrange

 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

@@ -73,8 +74,8 @@ class DDIMSampler(object):
                x_T=None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
                **kwargs
                ):
         if conditioning is not None:
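
(Note, not part of the diff: further down, dynamic_threshold is passed to torch.quantile, so it is meant to be a quantile in (0, 1] -- e.g. a value such as 0.995 would threshold at roughly the 99.5th percentile of |pred_x0| -- while the default None leaves the sampler's behaviour unchanged. A pure-GPU sketch of the clipping step follows the thresholding hunk below.)
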
@@ -106,6 +107,7 @@ class DDIMSampler(object):
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
+                                                   dynamic_threshold=dynamic_threshold,
                                                    )
         return samples, intermediates

@@ -115,7 +117,7 @@ class DDIMSampler(object):
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -150,7 +152,8 @@ class DDIMSampler(object):
                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning)
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
             if callback: callback(i)
             if img_callback: img_callback(pred_x0, i)
@@ -164,7 +167,8 @@ class DDIMSampler(object):
     @torch.no_grad()
     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
         b, *_, device = *x.shape, x.device

         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
@@ -194,6 +198,33 @@ class DDIMSampler(object):
         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

+        if dynamic_threshold is not None:
+            # renorm
+            pred_max = pred_x0.max()
+            pred_min = pred_x0.min()
+            pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1
+            pred_x0 = 2*pred_x0 - 1. # -1 ... 1
+
+            s = torch.quantile(
+                rearrange(pred_x0, 'b ... -> b (...)').abs(),
+                dynamic_threshold,
+                dim=-1
+            )
+            s.clamp_(min=1.0)
+            s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
+
+            # clip by threshold
+            #pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
+
+            # temporary hack: numpy on cpu
+            pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
+            pred_x0 = torch.tensor(pred_x0).to(self.model.device)
+
+            # re.renorm
+            pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
+            pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range
+
         # direction pointing to x_t
         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
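
Two remarks on the thresholding block above; neither is part of the diff. First, the initial renorm maps pred_x0 into [-1, 1] before the quantile is taken, so the per-sample quantile of |pred_x0| can never exceed 1; s.clamp_(min=1.0) therefore always yields s = 1.0, the clip-and-rescale changes nothing, and the re-renorm restores the original values (up to the float round trip) -- which may be what the "maybe?" in the commit message is about. Second, the numpy-on-CPU detour exists only because Tensor.clamp with tensor bounds needs a newer PyTorch, as the commented-out line notes. Below is a minimal pure-GPU sketch of the clipping step using torch.minimum/torch.maximum (available since PyTorch 1.7), applied directly to pred_x0 without the renorm, roughly as dynamic thresholding is described in the Imagen paper; it is an illustration, not the commit's code.

import torch
from einops import rearrange

def dynamic_threshold_sketch(pred_x0, q):
    # Clip pred_x0 to the per-sample q-quantile s of |pred_x0| (with s >= 1)
    # and rescale by s, entirely on pred_x0's device. The helper name, the
    # omission of the min-max renorm, and the toy values below are choices
    # made for this sketch, not taken from the commit.
    s = torch.quantile(rearrange(pred_x0, 'b ... -> b (...)').abs(), q, dim=-1)
    s.clamp_(min=1.0)  # never divide by less than 1, i.e. never amplify
    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))  # broadcast over non-batch dims
    # torch.minimum/torch.maximum broadcast against s, so no .cpu()/np.clip
    # round trip is needed; on PyTorch >= 1.9, pred_x0.clamp(-s, s) / s with
    # tensor bounds does the same thing.
    return torch.maximum(torch.minimum(pred_x0, s), -s) / s

if __name__ == "__main__":
    x = 3.0 * torch.randn(2, 4, 8, 8)        # predictions well outside [-1, 1]
    y = dynamic_threshold_sketch(x, 0.995)
    print(y.abs().max())                     # always <= 1 after clip and rescale

(The remaining hunk belongs to the second changed file, the encoders module.)
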
@@ -12,6 +12,10 @@ class AbstractEncoder(nn.Module):
     def encode(self, *args, **kwargs):
         raise NotImplementedError

+class IdentityEncoder(AbstractEncoder):
+
+    def encode(self, x):
+        return x


 class ClassEmbedder(nn.Module):