From 66df437e52826a5149a1c20dcc9f0be0abd0f685 Mon Sep 17 00:00:00 2001
From: Katherine Crowson
Date: Tue, 5 Apr 2022 11:35:05 -0700
Subject: [PATCH] Combine the two classifier-free guidance model outputs into a
 single batch

---
 ldm/models/diffusion/ddim.py | 11 +++++++----
 ldm/models/diffusion/plms.py | 12 +++++++-----
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index b8ff309..edf1eaf 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -166,11 +166,14 @@ class DDIMSampler(object):
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None):
         b, *_, device = *x.shape, x.device

-        e_t = self.model.apply_model(x, t, c)
-        if unconditional_guidance_scale != 1.:
-            assert unconditional_conditioning is not None
-            e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            c_in = torch.cat([unconditional_conditioning, c])
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

         if score_corrector is not None:
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 91a792f..78eeb10 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -176,11 +176,13 @@ class PLMSSampler(object):
         b, *_, device = *x.shape, x.device

         def get_model_output(x, t):
-            e_t = self.model.apply_model(x, t, c)
-
-            if unconditional_guidance_scale != 1.:
-                assert unconditional_conditioning is not None
-                e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

             if score_corrector is not None: