Combine the two classifier-free guidance model outputs into a single batch

This commit is contained in:
Katherine Crowson 2022-04-05 11:35:05 -07:00
parent f0c4e092c1
commit 66df437e52
2 changed files with 14 additions and 9 deletions

View file

@ -166,11 +166,14 @@ class DDIMSampler(object):
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
e_t = self.model.apply_model(x, t, c)
if unconditional_guidance_scale != 1.:
assert unconditional_conditioning is not None
e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:

View file

@ -176,11 +176,13 @@ class PLMSSampler(object):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
if unconditional_guidance_scale != 1.:
assert unconditional_conditioning is not None
e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None: