experimental support for guided upscale sampling
This commit is contained in:
parent
255d479088
commit
3c77204c0a
3 changed files with 27 additions and 3 deletions
|
@ -24,7 +24,7 @@ model:
|
|||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
timesteps: 1000
|
||||
max_noise_level: 250
|
||||
max_noise_level: 100
|
||||
output_size: 64
|
||||
model_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
|
|
|
@ -179,7 +179,20 @@ class DDIMSampler(object):
|
|||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_conditioning[k][i],
|
||||
c[k][i]]) for i in range(len(c[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_conditioning[k],
|
||||
c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
|
|
|
@ -1543,7 +1543,18 @@ class LatentUpscaleDiffusion(LatentDiffusion):
|
|||
log["denoise_row"] = denoise_grid
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
||||
uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
||||
# TODO explore better "unconditional" choices for the other keys
|
||||
uc = dict()
|
||||
for k in c:
|
||||
if k == "c_crossattn":
|
||||
assert isinstance(c[k], list) and len(c[k]) == 1
|
||||
uc[k] = [uc_tmp]
|
||||
elif isinstance(c[k], list):
|
||||
uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
|
||||
else:
|
||||
uc[k] = torch.zeros_like(c[k])
|
||||
|
||||
with ema_scope("Sampling with classifier-free guidance"):
|
||||
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta,
|
||||
|
|
Loading…
Reference in a new issue