Merge remote-tracking branch 'github/main' into main
This commit is contained in:
commit
3769989f20
2 changed files with 45 additions and 11 deletions
|
@ -110,11 +110,17 @@ lightning:
|
||||||
image_logger:
|
image_logger:
|
||||||
target: main.ImageLogger
|
target: main.ImageLogger
|
||||||
params:
|
params:
|
||||||
batch_frequency: 5000
|
batch_frequency: 50
|
||||||
max_images: 8
|
max_images: 4
|
||||||
increase_log_steps: False
|
increase_log_steps: False
|
||||||
log_first_step: False
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 3.0
|
||||||
|
unconditional_guidance_label: ""
|
||||||
|
|
||||||
trainer:
|
trainer:
|
||||||
#replace_sampler_ddp: False
|
#replace_sampler_ddp: False
|
||||||
|
|
|
@ -17,6 +17,7 @@ from functools import partial
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from torchvision.utils import make_grid
|
from torchvision.utils import make_grid
|
||||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||||
|
from omegaconf import ListConfig
|
||||||
|
|
||||||
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
|
||||||
from ldm.modules.ema import LitEma
|
from ldm.modules.ema import LitEma
|
||||||
|
@ -674,7 +675,6 @@ class LatentDiffusion(DDPM):
|
||||||
xc = x
|
xc = x
|
||||||
if not self.cond_stage_trainable or force_c_encode:
|
if not self.cond_stage_trainable or force_c_encode:
|
||||||
if isinstance(xc, dict) or isinstance(xc, list):
|
if isinstance(xc, dict) or isinstance(xc, list):
|
||||||
# import pudb; pudb.set_trace()
|
|
||||||
c = self.get_learned_conditioning(xc)
|
c = self.get_learned_conditioning(xc)
|
||||||
else:
|
else:
|
||||||
c = self.get_learned_conditioning(xc.to(self.device))
|
c = self.get_learned_conditioning(xc.to(self.device))
|
||||||
|
@ -1172,25 +1172,41 @@ class LatentDiffusion(DDPM):
|
||||||
mask=mask, x0=x0)
|
mask=mask, x0=x0)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
|
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
||||||
|
|
||||||
if ddim:
|
if ddim:
|
||||||
ddim_sampler = DDIMSampler(self)
|
ddim_sampler = DDIMSampler(self)
|
||||||
shape = (self.channels, self.image_size, self.image_size)
|
shape = (self.channels, self.image_size, self.image_size)
|
||||||
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
|
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
|
||||||
shape,cond,verbose=False,**kwargs)
|
shape, cond, verbose=False, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
|
||||||
return_intermediates=True,**kwargs)
|
return_intermediates=True, **kwargs)
|
||||||
|
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||||
|
if null_label is not None:
|
||||||
|
xc = null_label
|
||||||
|
if isinstance(xc, ListConfig):
|
||||||
|
xc = list(xc)
|
||||||
|
if isinstance(xc, dict) or isinstance(xc, list):
|
||||||
|
c = self.get_learned_conditioning(xc)
|
||||||
|
else:
|
||||||
|
c = self.get_learned_conditioning(xc.to(self.device))
|
||||||
|
else:
|
||||||
|
# todo: get null label from cond_stage_model
|
||||||
|
raise NotImplementedError()
|
||||||
|
c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
|
||||||
|
return c
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
||||||
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
||||||
plot_diffusion_rows=True, use_ema_scope=True, **kwargs):
|
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
|
||||||
|
use_ema_scope=True,
|
||||||
|
**kwargs):
|
||||||
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
||||||
|
|
||||||
use_ddim = ddim_steps is not None
|
use_ddim = ddim_steps is not None
|
||||||
|
@ -1240,7 +1256,7 @@ class LatentDiffusion(DDPM):
|
||||||
|
|
||||||
if sample:
|
if sample:
|
||||||
# get denoise row
|
# get denoise row
|
||||||
with ema_scope("Plotting"):
|
with ema_scope("Sampling"):
|
||||||
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
|
||||||
ddim_steps=ddim_steps,eta=ddim_eta)
|
ddim_steps=ddim_steps,eta=ddim_eta)
|
||||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||||
|
@ -1262,6 +1278,17 @@ class LatentDiffusion(DDPM):
|
||||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||||
log["samples_x0_quantized"] = x_samples
|
log["samples_x0_quantized"] = x_samples
|
||||||
|
|
||||||
|
if unconditional_guidance_scale > 1.0:
|
||||||
|
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
||||||
|
with ema_scope("Sampling with classifier-free guidance"):
|
||||||
|
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
||||||
|
ddim_steps=ddim_steps, eta=ddim_eta,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
)
|
||||||
|
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||||
|
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
||||||
|
@ -1278,6 +1305,7 @@ class LatentDiffusion(DDPM):
|
||||||
log["mask"] = mask
|
log["mask"] = mask
|
||||||
|
|
||||||
# outpaint
|
# outpaint
|
||||||
|
mask = 1. - mask
|
||||||
with ema_scope("Plotting Outpaint"):
|
with ema_scope("Plotting Outpaint"):
|
||||||
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
|
||||||
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
|
||||||
|
|
Loading…
Reference in a new issue