add cfg to log_images

rromb 2022-06-01 09:36:48 +02:00
parent 634a591701
commit b3a604d870


@@ -674,7 +674,6 @@ class LatentDiffusion(DDPM):
                 xc = x
             if not self.cond_stage_trainable or force_c_encode:
                 if isinstance(xc, dict) or isinstance(xc, list):
-                    # import pudb; pudb.set_trace()
                     c = self.get_learned_conditioning(xc)
                 else:
                     c = self.get_learned_conditioning(xc.to(self.device))
@@ -1172,25 +1171,38 @@ class LatentDiffusion(DDPM):
                                   mask=mask, x0=x0)
 
     @torch.no_grad()
-    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
         if ddim:
             ddim_sampler = DDIMSampler(self)
             shape = (self.channels, self.image_size, self.image_size)
-            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
-                                                        shape,cond,verbose=False,**kwargs)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
         else:
             samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
-                                                 return_intermediates=True,**kwargs)
+                                                 return_intermediates=True, **kwargs)
 
         return samples, intermediates
 
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                c = self.get_learned_conditioning(xc.to(self.device))
+        else:
+            # todo: get null label from cond_stage_model
+            raise NotImplementedError()
+        c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+        return c
+
     @torch.no_grad()
     def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                    quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
-                   plot_diffusion_rows=True, **kwargs):
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   **kwargs):
 
         use_ddim = ddim_steps is not None
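The new `get_unconditional_conditioning` helper encodes a single null label through the conditioning stage and tiles the resulting embedding across the batch via einops' `repeat`. A hypothetical usage sketch, assuming `model` is a trained `LatentDiffusion` instance with a text conditioning stage where an empty caption serves as the null label:

```python
# Hypothetical call: the empty-caption null label assumes a text conditioning
# stage, and batch_size=4 is illustrative. The single label is encoded to
# shape (1, ...) and then tiled to (4, ...) by repeat(c, '1 ... -> b ...').
uc = model.get_unconditional_conditioning(batch_size=4, null_label=[""])
```

Note that calling the helper without a `null_label` currently raises `NotImplementedError`; deriving the null label from `cond_stage_model` itself is left as a TODO in the diff.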
@@ -1239,7 +1251,7 @@ class LatentDiffusion(DDPM):
 
         if sample:
             # get denoise row
-            with self.ema_scope("Plotting"):
+            with self.ema_scope("Sampling"):
                 samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
                                                          ddim_steps=ddim_steps,eta=ddim_eta)
             # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
@@ -1261,6 +1273,17 @@ class LatentDiffusion(DDPM):
                 x_samples = self.decode_first_stage(samples.to(self.device))
                 log["samples_x0_quantized"] = x_samples
 
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            with self.ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
         if inpaint:
             # make a simple center square
             b, h, w = z.shape[0], z.shape[2], z.shape[3]
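`sample_log` forwards the two new keyword arguments through `**kwargs` to `DDIMSampler.sample`, where classifier-free guidance blends the conditional and unconditional noise predictions. A condensed sketch of that combination, paraphrasing the sampler's logic rather than quoting it verbatim:

```python
import torch

def guided_eps(model, x, t, c, uc, scale):
    # With no unconditional conditioning (or a scale of 1.0), fall back to
    # the plain conditional prediction.
    if uc is None or scale == 1.0:
        return model.apply_model(x, t, c)
    # Run the unconditional and conditional branches in one batched pass.
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    c_in = torch.cat([uc, c])
    e_t_uncond, e_t = model.apply_model(x_in, t_in, c_in).chunk(2)
    # Classifier-free guidance: push the prediction away from the
    # unconditional estimate, weighted by the guidance scale.
    return e_t_uncond + scale * (e_t - e_t_uncond)
```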
@@ -1277,6 +1300,7 @@ class LatentDiffusion(DDPM):
             log["mask"] = mask
 
             # outpaint
+            mask = 1. - mask
             with self.ema_scope("Plotting Outpaint"):
                 samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
                                              ddim_steps=ddim_steps, x0=z[:N], mask=mask)
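Taken together, a hypothetical `log_images` call that exercises the new path; any `unconditional_guidance_scale` above 1.0 triggers the guided sampling branch and adds a `samples_cfg_scale_*` entry to the returned log:

```python
# Hypothetical invocation; `batch` comes from the usual data loader, and the
# empty-caption label again assumes a text conditioning stage.
log = model.log_images(batch, N=4, ddim_steps=50,
                       unconditional_guidance_scale=5.0,
                       unconditional_guidance_label=[""])
x = log["samples_cfg_scale_5.00"]  # guided samples, decoded to image space
```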