experimental support for guided upscale sampling
parent 255d479088
commit 3c77204c0a
3 changed files with 27 additions and 3 deletions
@@ -24,7 +24,7 @@ model:
         linear_start: 0.00085
         linear_end: 0.0120
         timesteps: 1000
-        max_noise_level: 250
+        max_noise_level: 100
         output_size: 64
         model_config:
           target: ldm.models.autoencoder.AutoencoderKL
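The only functional change in the config is lowering max_noise_level from 250 to 100. Judging from the surrounding keys (timesteps, output_size, an AutoencoderKL model_config), this block configures the low-resolution conditioning encoder, and max_noise_level presumably caps the randomly drawn noise-augmentation timestep per sample; a minimal sketch of that assumption:

# Sketch only -- assumes max_noise_level bounds the per-sample noise-augmentation
# timestep for the low-res conditioning latent; names here are illustrative.
import torch

max_noise_level = 100          # value after this commit (was 250)
batch_size = 4

# one augmentation timestep per sample, always strictly below the cap
noise_level = torch.randint(0, max_noise_level, (batch_size,)).long()
print(noise_level)             # e.g. tensor([12, 87,  3, 55])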
@@ -179,7 +179,20 @@ class DDIMSampler(object):
         else:
             x_in = torch.cat([x] * 2)
             t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                            unconditional_conditioning[k],
+                            c[k]])
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

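The DDIMSampler change lets classifier-free guidance run when the conditioning is a dict (e.g. cross-attention embeddings plus a concatenated low-res latent) rather than a single tensor: unconditional and conditional entries are concatenated key by key so the model still sees one doubled batch, and the later .chunk(2) recovers e_t_uncond and e_t. A toy illustration with made-up key names and shapes (not the repo's model):

# Toy shapes only -- key names and sizes are assumptions for illustration.
import torch

c = {"c_crossattn": [torch.randn(2, 77, 768)],    # conditional text embeddings
     "c_concat":    [torch.randn(2, 4, 64, 64)]}  # low-res latent to concatenate
uc = {k: [torch.zeros_like(t) for t in v] for k, v in c.items()}  # "empty" conditioning

# concatenate unconditional first, conditional second, key by key (as in the diff)
c_in = {k: [torch.cat([uc[k][i], c[k][i]]) for i in range(len(c[k]))] for k in c}

for k, v in c_in.items():
    print(k, [tuple(t.shape) for t in v])
# c_crossattn [(4, 77, 768)] -- batch dimension doubled, ready for one forward pass
# c_concat    [(4, 4, 64, 64)]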
@@ -1543,7 +1543,18 @@ class LatentUpscaleDiffusion(LatentDiffusion):
             log["denoise_row"] = denoise_grid

         if unconditional_guidance_scale > 1.0:
-            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # TODO explore better "unconditional" choices for the other keys
+            uc = dict()
+            for k in c:
+                if k == "c_crossattn":
+                    assert isinstance(c[k], list) and len(c[k]) == 1
+                    uc[k] = [uc_tmp]
+                elif isinstance(c[k], list):
+                    uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
+                else:
+                    uc[k] = torch.zeros_like(c[k])
+
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                  ddim_steps=ddim_steps, eta=ddim_eta,
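In LatentUpscaleDiffusion.log_images, the unconditional conditioning is now assembled per key to mirror the structure of c: the learned "empty" conditioning replaces the cross-attention key, and every other key is zeroed out (the TODO flags this as a placeholder choice). A standalone sketch with hypothetical keys and shapes:

# Hypothetical tensors -- "c_concat"/"c_adm" and all shapes are illustrative, not from the repo.
import torch

N = 2
c = {"c_crossattn": [torch.randn(N, 77, 768)],     # text conditioning
     "c_concat":    [torch.randn(N, 4, 64, 64)],   # low-res conditioning latent
     "c_adm":        torch.randint(0, 100, (N,))}  # e.g. a noise-level index
uc_tmp = torch.zeros(N, 77, 768)  # stand-in for get_unconditional_conditioning(N, ...)

uc = dict()
for k in c:
    if k == "c_crossattn":
        uc[k] = [uc_tmp]                             # "empty" text conditioning
    elif isinstance(c[k], list):
        uc[k] = [torch.zeros_like(t) for t in c[k]]  # zero every list entry
    else:
        uc[k] = torch.zeros_like(c[k])               # zero plain tensors

print({k: type(v).__name__ for k, v in uc.items()})
# {'c_crossattn': 'list', 'c_concat': 'list', 'c_adm': 'Tensor'}

Presumably this uc is then handed to sample_log as the unconditional conditioning, which routes it into the dict-aware DDIMSampler branch added in the previous hunk.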