sampling can be addictive
This commit is contained in:
parent
3809ba7046
commit
181d1ad8f2
6 changed files with 88 additions and 39 deletions
|
@ -7,6 +7,7 @@ from functools import partial
|
|||
from einops import rearrange
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
@ -216,30 +217,7 @@ class DDIMSampler(object):
|
|||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
# renorm
|
||||
pred_max = pred_x0.max()
|
||||
pred_min = pred_x0.min()
|
||||
pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min) # 0 ... 1
|
||||
pred_x0 = 2*pred_x0 - 1. # -1 ... 1
|
||||
|
||||
s = torch.quantile(
|
||||
rearrange(pred_x0, 'b ... -> b (...)').abs(),
|
||||
dynamic_threshold,
|
||||
dim=-1
|
||||
)
|
||||
s.clamp_(min=1.0)
|
||||
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
||||
|
||||
# clip by threshold
|
||||
#pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||
|
||||
# temporary hack: numpy on cpu
|
||||
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
|
||||
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
||||
|
||||
# re.renorm
|
||||
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
||||
pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min # orig range
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
|
|
|
@ -6,6 +6,7 @@ from tqdm import tqdm
|
|||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
|
@ -77,6 +78,7 @@ class PLMSSampler(object):
|
|||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
|
@ -108,6 +110,7 @@ class PLMSSampler(object):
|
|||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
@ -117,7 +120,8 @@ class PLMSSampler(object):
|
|||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
|
@ -155,7 +159,8 @@ class PLMSSampler(object):
|
|||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next)
|
||||
old_eps=old_eps, t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
|
@ -172,7 +177,8 @@ class PLMSSampler(object):
|
|||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
|
@ -207,6 +213,8 @@ class PLMSSampler(object):
|
|||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
|
|
50
ldm/models/diffusion/sampling_util.py
Normal file
50
ldm/models/diffusion/sampling_util.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def renorm_thresholding(x0, value):
|
||||
# renorm
|
||||
pred_max = x0.max()
|
||||
pred_min = x0.min()
|
||||
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
|
||||
pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
|
||||
|
||||
s = torch.quantile(
|
||||
rearrange(pred_x0, 'b ... -> b (...)').abs(),
|
||||
value,
|
||||
dim=-1
|
||||
)
|
||||
s.clamp_(min=1.0)
|
||||
s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
|
||||
|
||||
# clip by threshold
|
||||
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||
|
||||
# temporary hack: numpy on cpu
|
||||
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
|
||||
pred_x0 = torch.tensor(pred_x0).to(self.model.device)
|
||||
|
||||
# re.renorm
|
||||
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
||||
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
|
||||
return pred_x0
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
6
scripts/prompts/six-prompts
Normal file
6
scripts/prompts/six-prompts
Normal file
|
@ -0,0 +1,6 @@
|
|||
the Tower of Babel by J.M.W. Turner
|
||||
advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art
|
||||
the gateway between dreams, trending on ArtStation
|
||||
Humanity is killed by AI, by James Gurney
|
||||
A fantasy painting of a city in a deep valley by Ivan Aivazovsky
|
||||
Darth Vader at Woodstock (1969)
|
|
@ -4,7 +4,7 @@ Bunny Vikings
|
|||
The Demogorgon from Stranger Thinhs holding a basketball
|
||||
Hamster in my microwave
|
||||
a courtroom sketch of a Ford Transit van
|
||||
PS1 Hagrid ad MCDonalds
|
||||
PS1 Hagrid at MCDonalds
|
||||
Karl Marx in KFC Logo
|
||||
Moai Statue giving a TED talk
|
||||
wahing machine trail cam
|
||||
|
|
|
@ -8,6 +8,7 @@ from itertools import islice
|
|||
from einops import rearrange
|
||||
from torchvision.utils import make_grid
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
@ -167,8 +168,14 @@ if __name__ == "__main__":
|
|||
default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
|
||||
help="path to checkpoint of model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="the seed (for reproducible sampling)",
|
||||
)
|
||||
opt = parser.parse_args()
|
||||
|
||||
seed_everything(opt.seed)
|
||||
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
@ -205,8 +212,8 @@ if __name__ == "__main__":
|
|||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
all_samples = list()
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
|
@ -235,16 +242,16 @@ if __name__ == "__main__":
|
|||
base_count += 1
|
||||
all_samples.append(x_samples_ddim)
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||
grid_count += 1
|
||||
|
||||
toc = time.time()
|
||||
|
||||
|
|
Loading…
Reference in a new issue