Combine the two classifier-free guidance model outputs into a single batch

This commit is contained in:
Katherine Crowson 2022-04-05 11:35:05 -07:00
parent f0c4e092c1
commit 66df437e52
2 changed files with 14 additions and 9 deletions

View file

@ -166,11 +166,14 @@ class DDIMSampler(object):
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None): unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
e_t = self.model.apply_model(x, t, c)
if unconditional_guidance_scale != 1.: if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
assert unconditional_conditioning is not None e_t = self.model.apply_model(x, t, c)
e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning) else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None: if score_corrector is not None:

View file

@ -176,11 +176,13 @@ class PLMSSampler(object):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
def get_model_output(x, t): def get_model_output(x, t):
e_t = self.model.apply_model(x, t, c) if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
if unconditional_guidance_scale != 1.: else:
assert unconditional_conditioning is not None x_in = torch.cat([x] * 2)
e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning) t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None: if score_corrector is not None: