diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 1e495ee..99a08c0 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -7,6 +7,7 @@ from functools import partial
 from einops import rearrange
 
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
 
 
 class DDIMSampler(object):
@@ -216,30 +217,7 @@ class DDIMSampler(object):
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
 
         if dynamic_threshold is not None:
-            # renorm
-            pred_max = pred_x0.max()
-            pred_min = pred_x0.min()
-            pred_x0 = (pred_x0-pred_min)/(pred_max-pred_min)  # 0 ... 1
-            pred_x0 = 2*pred_x0 - 1.  # -1 ... 1
-
-            s = torch.quantile(
-                rearrange(pred_x0, 'b ... -> b (...)').abs(),
-                dynamic_threshold,
-                dim=-1
-            )
-            s.clamp_(min=1.0)
-            s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
-
-            # clip by threshold
-            #pred_x0 = pred_x0.clamp(-s, s) / s  # needs newer pytorch  # TODO bring back to pure-gpu with min/max
-
-            # temporary hack: numpy on cpu
-            pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
-            pred_x0 = torch.tensor(pred_x0).to(self.model.device)
-
-            # re.renorm
-            pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
-            pred_x0 = (pred_max-pred_min)*pred_x0 + pred_min  # orig range
+            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
 
         # direction pointing to x_t
         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
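Note on the hunk above: the commented-out clamp in the removed block carries a TODO to bring the clipping back to pure GPU. A minimal sketch of that path, assuming PyTorch >= 1.9 (where Tensor.clamp accepts tensor bounds); the helper name is illustrative, not part of this patch:

import torch

def clip_by_threshold(pred_x0, s):
    # stay on the GPU: clamp each element to [-s, s] (s broadcasts over
    # the trailing singleton dims), then rescale back into [-1, 1]
    return pred_x0.clamp(-s, s) / s
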
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 78eeb10..7002a36 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -6,6 +6,7 @@ from tqdm import tqdm
 from functools import partial
 
 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
 
 
 class PLMSSampler(object):
@@ -77,6 +78,7 @@ class PLMSSampler(object):
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
                # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
                **kwargs
                ):
         if conditioning is not None:
@@ -108,6 +110,7 @@ class PLMSSampler(object):
                                                     log_every_t=log_every_t,
                                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                                     unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
                                                     )
         return samples, intermediates
 
@@ -117,7 +120,8 @@ class PLMSSampler(object):
                       callback=None, timesteps=None, quantize_denoised=False,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -155,7 +159,8 @@ class PLMSSampler(object):
                                       corrector_kwargs=corrector_kwargs,
                                       unconditional_guidance_scale=unconditional_guidance_scale,
                                       unconditional_conditioning=unconditional_conditioning,
-                                      old_eps=old_eps, t_next=ts_next)
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
             img, pred_x0, e_t = outs
             old_eps.append(e_t)
             if len(old_eps) >= 4:
@@ -172,7 +177,8 @@ class PLMSSampler(object):
     @torch.no_grad()
     def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
         b, *_, device = *x.shape, x.device
 
         def get_model_output(x, t):
@@ -207,6 +213,8 @@ class PLMSSampler(object):
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
             if quantize_denoised:
                 pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
             # direction pointing to x_t
             dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
             noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000..a0ae00f
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,52 @@
+import torch
+import numpy as np
+from einops import rearrange
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def renorm_thresholding(x0, value):
+    # renorm to -1 ... 1
+    pred_max = x0.max()
+    pred_min = x0.min()
+    pred_x0 = (x0 - pred_min) / (pred_max - pred_min)  # 0 ... 1
+    pred_x0 = 2 * pred_x0 - 1.  # -1 ... 1
+
+    s = torch.quantile(
+        rearrange(pred_x0, 'b ... -> b (...)').abs(),
+        value,
+        dim=-1
+    )
+    s.clamp_(min=1.0)
+    s = s.view(-1, *((1,) * (pred_x0.ndim - 1)))
+
+    # clip by threshold
+    # pred_x0 = pred_x0.clamp(-s, s) / s  # needs newer pytorch  # TODO bring back to pure-gpu with min/max
+
+    # temporary hack: numpy on cpu
+    pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy()
+    pred_x0 = torch.tensor(pred_x0).to(x0.device)
+
+    # undo the renorm
+    pred_x0 = (pred_x0 + 1.) / 2.  # 0 ... 1
+    pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min  # orig range
+    return pred_x0
+
+
+def norm_thresholding(x0, value):
+    # rescale each sample so that its RMS norm is at most `value`
+    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+    return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+    # x0: b c h w; cap the RMS norm over channels at each spatial position
+    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+    return x0 * (value / s)
\ No newline at end of file
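For reference, a usage sketch of the two norm-based helpers the samplers now call; the tensor shape and the threshold 1.0 are made-up values. Note that with norm_thresholding, dynamic_threshold acts as an absolute RMS bound rather than the quantile the removed renorm code treated it as:

import torch
from ldm.models.diffusion.sampling_util import norm_thresholding, spatial_norm_thresholding

x0 = torch.randn(2, 4, 32, 32) * 3.  # stand-in for a predicted x0 of shape (b, c, h, w)

# per-sample variant: any sample whose RMS norm exceeds 1.0 is rescaled onto the bound
out = norm_thresholding(x0, 1.0)
print(out.flatten(1).pow(2).mean(1).sqrt())  # every entry <= 1.0

# per-position variant: the RMS norm is taken over channels at each (h, w) location
out_sp = spatial_norm_thresholding(x0, 1.0)
print(out_sp.pow(2).mean(1).sqrt().max())  # <= 1.0
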
diff --git a/scripts/prompts/six-prompts b/scripts/prompts/six-prompts
new file mode 100644
index 0000000..6a041dc
--- /dev/null
+++ b/scripts/prompts/six-prompts
@@ -0,0 +1,6 @@
+the Tower of Babel by J.M.W. Turner
+advertisement for a psychedelic virtual reality headset, 16 bit sprite pixel art
+the gateway between dreams, trending on ArtStation
+Humanity is killed by AI, by James Gurney
+A fantasy painting of a city in a deep valley by Ivan Aivazovsky
+Darth Vader at Woodstock (1969)
diff --git a/scripts/prompts/weird-dalle-prompts.txt b/scripts/prompts/weird-dalle-prompts.txt
index 39ebf04..74b393d 100644
--- a/scripts/prompts/weird-dalle-prompts.txt
+++ b/scripts/prompts/weird-dalle-prompts.txt
@@ -4,7 +4,7 @@ Bunny Vikings
 The Demogorgon from Stranger Thinhs holding a basketball
 Hamster in my microwave
 a courtroom sketch of a Ford Transit van
-PS1 Hagrid ad MCDonalds
+PS1 Hagrid at MCDonalds
 Karl Marx in KFC Logo
 Moai Statue giving a TED talk
 wahing machine trail cam
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index e161501..cbf8525 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -8,6 +8,7 @@ from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
 import time
+from pytorch_lightning import seed_everything
 
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -167,8 +168,14 @@ if __name__ == "__main__":
         default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
         help="path to checkpoint of model",
     )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="the seed (for reproducible sampling)",
+    )
     opt = parser.parse_args()
-
+    seed_everything(opt.seed)
 
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
@@ -205,8 +212,8 @@ if __name__ == "__main__":
     with torch.no_grad():
         with model.ema_scope():
             tic = time.time()
+            all_samples = list()
             for n in trange(opt.n_iter, desc="Sampling"):
-                all_samples = list()
                 for prompts in tqdm(data, desc="data"):
                     uc = None
                     if opt.scale != 1.0:
@@ -235,16 +242,16 @@ if __name__ == "__main__":
                         base_count += 1
                     all_samples.append(x_samples_ddim)
 
-                if not opt.skip_grid:
-                    # additionally, save as grid
-                    grid = torch.stack(all_samples, 0)
-                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                    grid = make_grid(grid, nrow=n_rows)
+            if not opt.skip_grid:
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=n_rows)
 
-                    # to image
-                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                    grid_count += 1
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1
 
             toc = time.time()
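The txt2img.py hunks do two things: seed_everything(opt.seed) seeds Python's random, numpy and torch in one call, and hoisting all_samples out of the n_iter loop (together with dedenting the grid block) saves one grid over all iterations instead of one per iteration. A hypothetical invocation exercising the new flag, with the script's other options left at their defaults:

python scripts/txt2img.py --seed 123

Two runs with the same --seed should now yield identical samples.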