Combine the two classifier-free guidance model outputs into a single batch
parent f0c4e092c1
commit 66df437e52
2 changed files with 14 additions and 9 deletions
@@ -166,11 +166,14 @@ class DDIMSampler(object):
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None):
         b, *_, device = *x.shape, x.device
-        e_t = self.model.apply_model(x, t, c)
 
-        if unconditional_guidance_scale != 1.:
-            assert unconditional_conditioning is not None
-            e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            c_in = torch.cat([unconditional_conditioning, c])
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
 
         if score_corrector is not None:
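For context, here is a minimal, self-contained sketch of the batched classifier-free guidance pattern introduced in the hunk above. ToyDenoiser and guided_eps are illustrative stand-ins, not part of either sampler; only the torch.cat / chunk(2) / guidance arithmetic mirrors the diff, and the real model call is self.model.apply_model(x, t, c).

# Sketch only: ToyDenoiser stands in for the diffusion model's epsilon-predictor.
import torch
import torch.nn as nn


class ToyDenoiser(nn.Module):
    """Placeholder epsilon-predictor with the same (x, t, c) call shape."""

    def __init__(self, channels=4, cond_dim=8):
        super().__init__()
        self.proj = nn.Linear(cond_dim, channels)

    def forward(self, x, t, c):
        # Pretend the conditioning shifts the predicted noise.
        return x + self.proj(c)[:, :, None, None]


def guided_eps(model, x, t, c, uc, guidance_scale):
    """One batched forward pass instead of two separate calls."""
    x_in = torch.cat([x] * 2)                  # (2B, C, H, W)
    t_in = torch.cat([t] * 2)                  # (2B,)
    c_in = torch.cat([uc, c])                  # unconditional first, as in the diff
    e_t_uncond, e_t = model(x_in, t_in, c_in).chunk(2)
    return e_t_uncond + guidance_scale * (e_t - e_t_uncond)


if __name__ == "__main__":
    model = ToyDenoiser()
    x = torch.randn(2, 4, 8, 8)                # noisy latents
    t = torch.full((2,), 500)                  # timestep indices
    c = torch.randn(2, 8)                      # conditioning embedding
    uc = torch.zeros(2, 8)                     # "empty" conditioning
    print(guided_eps(model, x, t, c, uc, guidance_scale=7.5).shape)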
@@ -176,11 +176,13 @@ class PLMSSampler(object):
         b, *_, device = *x.shape, x.device
 
         def get_model_output(x, t):
-            e_t = self.model.apply_model(x, t, c)
-
-            if unconditional_guidance_scale != 1.:
-                assert unconditional_conditioning is not None
-                e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
 
             if score_corrector is not None:
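The PLMS change is the same substitution inside get_model_output. As a sanity check (a sketch, not part of the commit), the batched call matches the original two-call formulation for any epsilon-predictor that treats batch elements independently, e.g. the ToyDenoiser above:

import torch


def guided_eps_two_calls(model, x, t, c, uc, scale):
    # Original formulation: one conditional and one unconditional forward pass.
    e_t = model(x, t, c)
    e_t_uncond = model(x, t, uc)
    return e_t_uncond + scale * (e_t - e_t_uncond)


def guided_eps_batched(model, x, t, c, uc, scale):
    # New formulation: a single forward pass over a doubled batch.
    e_t_uncond, e_t = model(torch.cat([x] * 2),
                            torch.cat([t] * 2),
                            torch.cat([uc, c])).chunk(2)
    return e_t_uncond + scale * (e_t - e_t_uncond)

# Expected to agree up to float tolerance when there are no cross-batch
# interactions (e.g. no batch normalization across samples):
# torch.allclose(guided_eps_two_calls(model, x, t, c, uc, 7.5),
#                guided_eps_batched(model, x, t, c, uc, 7.5))

The trade-off is one larger forward pass (batch size 2B, roughly double the activation memory) in exchange for halving the number of model calls per sampling step.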