sampling can be addictive
This commit is contained in:
parent
3809ba7046
commit
181d1ad8f2
6 changed files with 88 additions and 39 deletions
|
@ -7,6 +7,7 @@ from functools import partial
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
|
@ -216,30 +217,7 @@ class DDIMSampler(object):
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
|
||||||
if dynamic_threshold is not None:
|
if dynamic_threshold is not None:
|
||||||
# renorm
|
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||||
pred_max = pred_x0.max()
|
|
||||||
pred_min = pred_x0.min()
|
|
||||||
pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1
|
|
||||||
pred_x0 = 2*pred_x0 - 1. # -1 ... 1
|
|
||||||
|
|
||||||
s = torch.quantile(
|
|
||||||
rearrange(pred_x0, 'b ... -> b (...)').abs(),
|
|
||||||
dynamic_threshold,
|
|
||||||
dim=-1
|
|
||||||
)
|
|
||||||
s.clamp_(min=1.0)
|
|
||||||
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
|
||||||
|
|
||||||
# clip by threshold
|
|
||||||
#pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
|
||||||
|
|
||||||
# temporary hack: numpy on cpu
|
|
||||||
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
|
|
||||||
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
|
||||||
|
|
||||||
# re.renorm
|
|
||||||
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
|
||||||
pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range
|
|
||||||
|
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
|
|
@ -6,6 +6,7 @@ from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
class PLMSSampler(object):
|
class PLMSSampler(object):
|
||||||
|
@ -77,6 +78,7 @@ class PLMSSampler(object):
|
||||||
unconditional_guidance_scale=1.,
|
unconditional_guidance_scale=1.,
|
||||||
unconditional_conditioning=None,
|
unconditional_conditioning=None,
|
||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
dynamic_threshold=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
|
@ -108,6 +110,7 @@ class PLMSSampler(object):
|
||||||
log_every_t=log_every_t,
|
log_every_t=log_every_t,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold,
|
||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
|
@ -117,7 +120,8 @@ class PLMSSampler(object):
|
||||||
callback=None, timesteps=None, quantize_denoised=False,
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||||
|
dynamic_threshold=None):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
|
@ -155,7 +159,8 @@ class PLMSSampler(object):
|
||||||
corrector_kwargs=corrector_kwargs,
|
corrector_kwargs=corrector_kwargs,
|
||||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
unconditional_conditioning=unconditional_conditioning,
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
old_eps=old_eps, t_next=ts_next)
|
old_eps=old_eps, t_next=ts_next,
|
||||||
|
dynamic_threshold=dynamic_threshold)
|
||||||
img, pred_x0, e_t = outs
|
img, pred_x0, e_t = outs
|
||||||
old_eps.append(e_t)
|
old_eps.append(e_t)
|
||||||
if len(old_eps) >= 4:
|
if len(old_eps) >= 4:
|
||||||
|
@ -172,7 +177,8 @@ class PLMSSampler(object):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||||
|
dynamic_threshold=None):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
def get_model_output(x, t):
|
def get_model_output(x, t):
|
||||||
|
@ -207,6 +213,8 @@ class PLMSSampler(object):
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
if quantize_denoised:
|
if quantize_denoised:
|
||||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
if dynamic_threshold is not None:
|
||||||
|
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||||
# direction pointing to x_t
|
# direction pointing to x_t
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
|
50
ldm/models/diffusion/sampling_util.py
Normal file
50
ldm/models/diffusion/sampling_util.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def append_dims(x, target_dims):
|
||||||
|
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||||
|
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||||
|
dims_to_append = target_dims - x.ndim
|
||||||
|
if dims_to_append < 0:
|
||||||
|
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||||
|
return x[(...,) + (None,) * dims_to_append]
|
||||||
|
|
||||||
|
|
||||||
|
def renorm_thresholding(x0, value):
|
||||||
|
# renorm
|
||||||
|
pred_max = x0.max()
|
||||||
|
pred_min = x0.min()
|
||||||
|
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
|
||||||
|
pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
|
||||||
|
|
||||||
|
s = torch.quantile(
|
||||||
|
rearrange(pred_x0, 'b ... -> b (...)').abs(),
|
||||||
|
value,
|
||||||
|
dim=-1
|
||||||
|
)
|
||||||
|
s.clamp_(min=1.0)
|
||||||
|
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
||||||
|
|
||||||
|
# clip by threshold
|
||||||
|
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||||
|
|
||||||
|
# temporary hack: numpy on cpu
|
||||||
|
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
|
||||||
|
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
||||||
|
|
||||||
|
# re.renorm
|
||||||
|
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
||||||
|
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
|
||||||
|
return pred_x0
|
||||||
|
|
||||||
|
|
||||||
|
def norm_thresholding(x0, value):
|
||||||
|
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||||
|
return x0 * (value / s)
|
||||||
|
|
||||||
|
|
||||||
|
def spatial_norm_thresholding(x0, value):
|
||||||
|
# b c h w
|
||||||
|
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||||
|
return x0 * (value / s)
|
6
scripts/prompts/six-prompts
Normal file
6
scripts/prompts/six-prompts
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
the Tower of Babel by J.M.W. Turner
|
||||||
|
advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art
|
||||||
|
the gateway between dreams, trending on ArtStation
|
||||||
|
Humanity is killed by AI, by James Gurney
|
||||||
|
A fantasy painting of a city in a deep valley by Ivan Aivazovsky
|
||||||
|
Darth Vader at Woodstock (1969)
|
|
@ -4,7 +4,7 @@ Bunny Vikings
|
||||||
The Demogorgon from Stranger Thinhs holding a basketball
|
The Demogorgon from Stranger Thinhs holding a basketball
|
||||||
Hamster in my microwave
|
Hamster in my microwave
|
||||||
a courtroom sketch of a Ford Transit van
|
a courtroom sketch of a Ford Transit van
|
||||||
PS1 Hagrid ad MCDonalds
|
PS1 Hagrid at MCDonalds
|
||||||
Karl Marx in KFC Logo
|
Karl Marx in KFC Logo
|
||||||
Moai Statue giving a TED talk
|
Moai Statue giving a TED talk
|
||||||
wahing machine trail cam
|
wahing machine trail cam
|
||||||
|
|
|
@ -8,6 +8,7 @@ from itertools import islice
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
import time
|
import time
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
@ -167,8 +168,14 @@ if __name__ == "__main__":
|
||||||
default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
|
default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
|
||||||
help="path to checkpoint of model",
|
help="path to checkpoint of model",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="the seed (for reproducible sampling)",
|
||||||
|
)
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
config = OmegaConf.load(f"{opt.config}")
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
@ -205,8 +212,8 @@ if __name__ == "__main__":
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
all_samples = list()
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
all_samples = list()
|
|
||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
uc = None
|
uc = None
|
||||||
if opt.scale != 1.0:
|
if opt.scale != 1.0:
|
||||||
|
@ -235,16 +242,16 @@ if __name__ == "__main__":
|
||||||
base_count += 1
|
base_count += 1
|
||||||
all_samples.append(x_samples_ddim)
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
if not opt.skip_grid:
|
if not opt.skip_grid:
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = torch.stack(all_samples, 0)
|
grid = torch.stack(all_samples, 0)
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
grid = make_grid(grid, nrow=n_rows)
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
# to image
|
# to image
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
grid_count += 1
|
grid_count += 1
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue