From 634a591701bd14d4de496b2075ed8a436b343002 Mon Sep 17 00:00:00 2001
From: rromb
Date: Wed, 1 Jun 2022 09:36:27 +0200
Subject: [PATCH 1/4] adapted sampling config for clip-encoder training

---
 .../txt2img-1p4B-multinode-clip-encoder.yaml | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
index da87af4..22db311 100644
--- a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
+++ b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
@@ -110,11 +110,16 @@ lightning:
     image_logger:
       target: main.ImageLogger
       params:
-        batch_frequency: 5000
-        max_images: 8
+        batch_frequency: 50
+        max_images: 4
         increase_log_steps: False
         log_first_step: False
-
+        log_images_kwargs:
+          plot_progressive_rows: False
+          plot_diffusion_rows: False
+          N: 4
+          unconditional_guidance_scale: 3.0
+          unconditional_guidance_label: ""
 
   trainer:
     #replace_sampler_ddp: False

From b3a604d870319bfc87e2949154b8c340b9293ee0 Mon Sep 17 00:00:00 2001
From: rromb
Date: Wed, 1 Jun 2022 09:36:48 +0200
Subject: [PATCH 2/4] add cfg to log_images

---
 ldm/models/diffusion/ddpm.py | 40 ++++++++++++++++++++++++++++--------
 1 file changed, 32 insertions(+), 8 deletions(-)

diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 49b8ecc..4881913 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -674,7 +674,6 @@ class LatentDiffusion(DDPM):
                 xc = x
             if not self.cond_stage_trainable or force_c_encode:
                 if isinstance(xc, dict) or isinstance(xc, list):
-                    # import pudb; pudb.set_trace()
                     c = self.get_learned_conditioning(xc)
                 else:
                     c = self.get_learned_conditioning(xc.to(self.device))
@@ -1172,25 +1171,38 @@ class LatentDiffusion(DDPM):
                                   mask=mask, x0=x0)
 
     @torch.no_grad()
-    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
-
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
         if ddim:
             ddim_sampler = DDIMSampler(self)
             shape = (self.channels, self.image_size, self.image_size)
-            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
-                                                        shape,cond,verbose=False,**kwargs)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
 
         else:
             samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
-                                                 return_intermediates=True,**kwargs)
+                                                 return_intermediates=True, **kwargs)
 
         return samples, intermediates
 
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                c = self.get_learned_conditioning(xc.to(self.device))
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
-                   plot_diffusion_rows=True, **kwargs):
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   **kwargs):
 
         use_ddim = ddim_steps is not None
@@ -1239,7 +1251,7 @@ class LatentDiffusion(DDPM):
 
         if sample:
             # get denoise row
-            with self.ema_scope("Plotting"):
+            with self.ema_scope("Sampling"):
                 samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
                                                          ddim_steps=ddim_steps,eta=ddim_eta)
             # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
@@ -1261,6 +1273,17 @@ class LatentDiffusion(DDPM):
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_x0_quantized"] = x_samples
 
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            with self.ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
         if inpaint:
             # make a simple center square
             b, h, w = z.shape[0], z.shape[2], z.shape[3]
@@ -1277,6 +1300,7 @@ class LatentDiffusion(DDPM):
                 log["mask"] = mask
 
             # outpaint
+            mask = 1. - mask
             with self.ema_scope("Plotting Outpaint"):
                 samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
                                              ddim_steps=ddim_steps, x0=z[:N], mask=mask)
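For reference, the two new kwargs are consumed downstream: sample_log passes unconditional_guidance_scale and unconditional_conditioning through **kwargs to DDIMSampler.sample, which combines the conditional and unconditional noise predictions at every step. A minimal sketch of that combination, assuming the sampler batches both branches into a single forward pass (the function name and batching layout here are illustrative, not the sampler's exact internals):

    import torch

    def cfg_eps(model, x, t, cond, uncond, scale):
        # run the conditional and unconditional branches in one batched forward
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond, cond])
        e_uncond, e_cond = model.apply_model(x_in, t_in, c_in).chunk(2)
        # extrapolate away from the unconditional prediction;
        # scale == 1.0 recovers the purely conditional epsilon
        return e_uncond + scale * (e_cond - e_uncond)

With the scale of 3.0 set in patch 1, log_images takes the new unconditional_guidance_scale > 1.0 branch and logs the guided samples under samples_cfg_scale_3.00.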
From fff19bf82ec55946cde7a28374d6cfa5773e3e72 Mon Sep 17 00:00:00 2001
From: rromb
Date: Wed, 1 Jun 2022 09:52:17 +0200
Subject: [PATCH 3/4] handle listconfig

---
 ldm/models/diffusion/ddpm.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 4881913..06a04bd 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -17,6 +17,7 @@ from functools import partial
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
 
 from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
 from ldm.modules.ema import LitEma
@@ -1188,6 +1189,8 @@ class LatentDiffusion(DDPM):
     def get_unconditional_conditioning(self, batch_size, null_label=None):
         if null_label is not None:
             xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
             if isinstance(xc, dict) or isinstance(xc, list):
                 c = self.get_learned_conditioning(xc)
             else:
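The ListConfig guard is needed because OmegaConf list values are not Python lists, so a label list coming from the YAML would otherwise miss the dict/list branch and hit xc.to(self.device). A quick demonstration (the config contents are made up for illustration):

    from omegaconf import OmegaConf, ListConfig

    cfg = OmegaConf.create({"unconditional_guidance_label": ["a photo"]})
    label = cfg.unconditional_guidance_label
    assert isinstance(label, ListConfig)
    assert not isinstance(label, list)  # would bypass the dict/list branch
    label = list(label)                 # plain list; get_learned_conditioning accepts it

Note that a bare string label, such as the "" set in patch 1, still falls through to the tensor branch and would raise AttributeError on str.to; wrapping it in a one-element list before the dispatch (e.g. xc = [null_label]) would route it through get_learned_conditioning while keeping the batch dimension of 1 that the repeat expects.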
From 8bb094a0ec705dad861624f3da49e5cafa81d78b Mon Sep 17 00:00:00 2001
From: rromb
Date: Wed, 1 Jun 2022 10:10:19 +0200
Subject: [PATCH 4/4] final nerfed sampling config

---
 configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
index 22db311..ca88ebd 100644
--- a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
+++ b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder.yaml
@@ -115,6 +115,7 @@ lightning:
         increase_log_steps: False
         log_first_step: False
         log_images_kwargs:
+          inpaint: False
           plot_progressive_rows: False
          plot_diffusion_rows: False
           N: 4
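Taken together, the final nerfed config amounts to the following kwargs dict, assuming main.ImageLogger forwards log_images_kwargs verbatim into pl_module.log_images(batch, **kwargs) (a reconstruction of the YAML for readability, not new behavior):

    log_images_kwargs = dict(
        N=4,                               # only 4 samples per log step
        inpaint=False,                     # skip the inpaint/outpaint branches
        plot_progressive_rows=False,
        plot_diffusion_rows=False,
        unconditional_guidance_scale=3.0,  # enables the CFG branch (> 1.0)
        unconditional_guidance_label="",   # empty prompt as the null condition
    )

Combined with batch_frequency: 50, this keeps image logging cheap and frequent while skipping the slower diffusion-row, progressive-row, and inpainting visualizations.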