optional use of ema

This commit is contained in:
Patrick Esser 2022-06-01 08:46:11 +00:00 committed by root
parent abd9fb7af1
commit 2cbbb3148a
1 changed files with 8 additions and 7 deletions

View File

@ -12,7 +12,7 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat from einops import rearrange, repeat
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from functools import partial from functools import partial
from tqdm import tqdm from tqdm import tqdm
from torchvision.utils import make_grid from torchvision.utils import make_grid
@ -1190,7 +1190,8 @@ class LatentDiffusion(DDPM):
@torch.no_grad() @torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs): plot_diffusion_rows=True, use_ema_scope=True, **kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None use_ddim = ddim_steps is not None
@ -1239,7 +1240,7 @@ class LatentDiffusion(DDPM):
if sample: if sample:
# get denoise row # get denoise row
with self.ema_scope("Plotting"): with ema_scope("Plotting"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta) ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
@ -1252,7 +1253,7 @@ class LatentDiffusion(DDPM):
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage): self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling # also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"): with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta, ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True) quantize_denoised=True)
@ -1268,7 +1269,7 @@ class LatentDiffusion(DDPM):
# zeros will be filled in # zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...] mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"): with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask) ddim_steps=ddim_steps, x0=z[:N], mask=mask)
@ -1277,14 +1278,14 @@ class LatentDiffusion(DDPM):
log["mask"] = mask log["mask"] = mask
# outpaint # outpaint
with self.ema_scope("Plotting Outpaint"): with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask) ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device)) x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples log["samples_outpainting"] = x_samples
if plot_progressive_rows: if plot_progressive_rows:
with self.ema_scope("Plotting Progressives"): with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c, img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size), shape=(self.channels, self.image_size, self.image_size),
batch_size=N) batch_size=N)