optional use of ema

This commit is contained in:
Patrick Esser 2022-06-01 08:46:11 +00:00 committed by root
parent abd9fb7af1
commit 2cbbb3148a
1 changed files with 8 additions and 7 deletions

View File

@ -12,7 +12,7 @@ import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
@ -1190,7 +1190,8 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs):
plot_diffusion_rows=True, use_ema_scope=True, **kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
@ -1239,7 +1240,7 @@ class LatentDiffusion(DDPM):
if sample:
# get denoise row
with self.ema_scope("Plotting"):
with ema_scope("Plotting"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
@ -1252,7 +1253,7 @@ class LatentDiffusion(DDPM):
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True)
@ -1268,7 +1269,7 @@ class LatentDiffusion(DDPM):
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
@ -1277,14 +1278,14 @@ class LatentDiffusion(DDPM):
log["mask"] = mask
# outpaint
with self.ema_scope("Plotting Outpaint"):
with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with self.ema_scope("Plotting Progressives"):
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)