optional use of ema
This commit is contained in:
parent
abd9fb7af1
commit
2cbbb3148a
1 changed files with 8 additions and 7 deletions
|
@ -12,7 +12,7 @@ import numpy as np
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torch.optim.lr_scheduler import LambdaLR
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager, nullcontext
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
|
@ -1190,7 +1190,8 @@ class LatentDiffusion(DDPM):
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
||||||
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
||||||
plot_diffusion_rows=True, **kwargs):
|
plot_diffusion_rows=True, use_ema_scope=True, **kwargs):
|
||||||
|
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
||||||
|
|
||||||
use_ddim = ddim_steps is not None
|
use_ddim = ddim_steps is not None
|
||||||
|
|
||||||
|
@ -1239,7 +1240,7 @@ class LatentDiffusion(DDPM):
|
||||||
|
|
||||||
if sample:
|
if sample:
|
||||||
# get denoise row
|
# get denoise row
|
||||||
with self.ema_scope("Plotting"):
|
with ema_scope("Plotting"):
|
||||||
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
||||||
ddim_steps=ddim_steps,eta=ddim_eta)
|
ddim_steps=ddim_steps,eta=ddim_eta)
|
||||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||||
|
@ -1252,7 +1253,7 @@ class LatentDiffusion(DDPM):
|
||||||
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
|
||||||
self.first_stage_model, IdentityFirstStage):
|
self.first_stage_model, IdentityFirstStage):
|
||||||
# also display when quantizing x0 while sampling
|
# also display when quantizing x0 while sampling
|
||||||
with self.ema_scope("Plotting Quantized Denoised"):
|
with ema_scope("Plotting Quantized Denoised"):
|
||||||
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
||||||
ddim_steps=ddim_steps,eta=ddim_eta,
|
ddim_steps=ddim_steps,eta=ddim_eta,
|
||||||
quantize_denoised=True)
|
quantize_denoised=True)
|
||||||
|
@ -1268,7 +1269,7 @@ class LatentDiffusion(DDPM):
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
mask = mask[:, None, ...]
|
mask = mask[:, None, ...]
|
||||||
with self.ema_scope("Plotting Inpaint"):
|
with ema_scope("Plotting Inpaint"):
|
||||||
|
|
||||||
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
|
||||||
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
||||||
|
@ -1277,14 +1278,14 @@ class LatentDiffusion(DDPM):
|
||||||
log["mask"] = mask
|
log["mask"] = mask
|
||||||
|
|
||||||
# outpaint
|
# outpaint
|
||||||
with self.ema_scope("Plotting Outpaint"):
|
with ema_scope("Plotting Outpaint"):
|
||||||
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
||||||
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
||||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||||
log["samples_outpainting"] = x_samples
|
log["samples_outpainting"] = x_samples
|
||||||
|
|
||||||
if plot_progressive_rows:
|
if plot_progressive_rows:
|
||||||
with self.ema_scope("Plotting Progressives"):
|
with ema_scope("Plotting Progressives"):
|
||||||
img, progressives = self.progressive_denoising(c,
|
img, progressives = self.progressive_denoising(c,
|
||||||
shape=(self.channels, self.image_size, self.image_size),
|
shape=(self.channels, self.image_size, self.image_size),
|
||||||
batch_size=N)
|
batch_size=N)
|
||||||
|
|
Loading…
Reference in a new issue