From 2cbbb3148a7eb070c8e618a89045be7577e54133 Mon Sep 17 00:00:00 2001 From: Patrick Esser Date: Wed, 1 Jun 2022 08:46:11 +0000 Subject: [PATCH] optional use of ema --- ldm/models/diffusion/ddpm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 49b8ecc..2d218bc 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -12,7 +12,7 @@ import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR from einops import rearrange, repeat -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from functools import partial from tqdm import tqdm from torchvision.utils import make_grid @@ -1190,7 +1190,8 @@ class LatentDiffusion(DDPM): @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): + plot_diffusion_rows=True, use_ema_scope=True, **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext use_ddim = ddim_steps is not None @@ -1239,7 +1240,7 @@ class LatentDiffusion(DDPM): if sample: # get denoise row - with self.ema_scope("Plotting"): + with ema_scope("Plotting"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) @@ -1252,7 +1253,7 @@ class LatentDiffusion(DDPM): if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( self.first_stage_model, IdentityFirstStage): # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): + with ema_scope("Plotting Quantized Denoised"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta, quantize_denoised=True) @@ -1268,7 +1269,7 @@ class LatentDiffusion(DDPM): # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask = mask[:, None, ...] - with self.ema_scope("Plotting Inpaint"): + with ema_scope("Plotting Inpaint"): samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask) @@ -1277,14 +1278,14 @@ class LatentDiffusion(DDPM): log["mask"] = mask # outpaint - with self.ema_scope("Plotting Outpaint"): + with ema_scope("Plotting Outpaint"): samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): + with ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising(c, shape=(self.channels, self.image_size, self.image_size), batch_size=N)