diff --git a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
index 7ca45b8..2d98cc7 100644
--- a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
+++ b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024-dev.yaml
@@ -22,10 +22,11 @@ model:
     low_scale_config:
       target: ldm.modules.encoders.modules.LowScaleEncoder
       params:
+        scale_factor: 0.18215
         linear_start: 0.00085
         linear_end: 0.0120
         timesteps: 1000
-        max_noise_level: 250
+        max_noise_level: 100
         output_size: 64
         model_config:
           target: ldm.models.autoencoder.AutoencoderKL
@@ -160,13 +161,14 @@ lightning:
         increase_log_steps: False
         log_first_step: False
         log_images_kwargs:
+          sample: False
           use_ema_scope: False
           inpaint: False
           plot_progressive_rows: False
           plot_diffusion_rows: False
           N: 4
-          unconditional_guidance_scale: 3.0
-          unconditional_guidance_label: [""]
+          #unconditional_guidance_scale: 3.0
+          #unconditional_guidance_label: [""]
 
   trainer:
     benchmark: True
diff --git a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
index 9a45eed..e6b9db5 100644
--- a/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
+++ b/configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml
@@ -20,6 +20,7 @@ model:
     low_scale_config:
       target: ldm.modules.encoders.modules.LowScaleEncoder
       params:
+        scale_factor: 0.18215
         linear_start: 0.00085
         linear_end: 0.0120
         timesteps: 1000
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 7d6cb48..3d01f60 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -80,9 +80,12 @@ class DDIMSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
             else:
                 if conditioning.shape[0] != batch_size:
                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 8846cab..6aa8b86 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -1278,10 +1278,10 @@ class LatentDiffusion(DDPM):
                 xc = self.cond_stage_model.decode(c)
                 log["conditioning"] = xc
             elif self.cond_stage_key in ["caption", "txt"]:
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
                 log["conditioning"] = xc
             elif self.cond_stage_key == 'class_label':
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
                 log['conditioning'] = xc
             elif isimage(xc):
                 log["conditioning"] = xc
@@ -1463,16 +1463,105 @@ class LatentUpscaleDiffusion(LatentDiffusion):
             param.requires_grad = False
 
     @torch.no_grad()
-    def get_input(self, batch, k, cond_key=None, bs=None):
-        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
-        x_low = batch[self.low_scale_key]
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                  force_c_encode=True, return_original_cond=True, bs=bs)
+        x_low = batch[self.low_scale_key][:bs]
         x_low = rearrange(x_low, 'b h w c -> b c h w')
         x_low = x_low.to(memory_format=torch.contiguous_format).float()
         zx, noise_level = self.low_scale_model(x_low)
         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
         return z, all_conds
 
-    # TODO log it
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                          log_mode=True)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key == 'class_label':
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        return log
 
 
 class Layout2ImgDiffusion(LatentDiffusion):
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index a87f6a0..21ba5b4 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -208,13 +208,15 @@ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_t
 
 
 class LowScaleEncoder(nn.Module):
-    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
+                 scale_factor=1.0):
         super().__init__()
         self.max_noise_level = max_noise_level
         self.model = instantiate_from_config(model_config)
         self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                                                             linear_end=linear_end)
         self.out_size = output_size
+        self.scale_factor = scale_factor
 
     def register_schedule(self, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -250,12 +252,17 @@ class LowScaleEncoder(nn.Module):
 
     def forward(self, x):
         z = self.model.encode(x).sample()
+        z = z * self.scale_factor
         noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
         z = self.q_sample(z, noise_level)
         #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
         z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level
 
+    def decode(self, z):
+        z = z / self.scale_factor
+        return self.model.decode(z)
+
 
 if __name__ == "__main__":
     from ldm.util import count_params