diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 0b9bc98..aa8b4ff 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -1473,9 +1473,10 @@ class LatentUpscaleDiffusion(LatentDiffusion):
         x_low = x_low.to(memory_format=torch.contiguous_format).float()
         zx, noise_level = self.low_scale_model(x_low)
         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        #import pudb; pu.db
         if log_mode:
             # TODO: maybe disable if too expensive
-            interpretability = True
+            interpretability = False
             if interpretability:
                 zx = zx[:, :, ::2, ::2]
             x_low_rec = self.low_scale_model.decode(zx)
@@ -1553,13 +1554,13 @@ class LatentUpscaleDiffusion(LatentDiffusion):
                 if k == "c_crossattn":
                     assert isinstance(c[k], list) and len(c[k]) == 1
                     uc[k] = [uc_tmp]
-                elif k == "c_adm":
+                elif k == "c_adm":  # todo: only run with text-based guidance?
                     assert isinstance(c[k], torch.Tensor)
-                    uc[k] = torch.ones_like(c[k]) * (self.low_scale_model.max_noise_level-1)
+                    uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                 elif isinstance(c[k], list):
-                    uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                 else:
-                    uc[k] = torch.zeros_like(c[k])
+                    uc[k] = c[k]
 
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
@@ -1628,8 +1629,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
                 new_entry[:, :self.keep_dims, ...] = sd[k]
                 sd[k] = new_entry
 
-        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
-            sd, strict=False)
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
         if len(missing) > 0:
             print(f"Missing Keys: {missing}")
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 21ba5b4..84e5604 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -255,8 +255,9 @@ class LowScaleEncoder(nn.Module):
         z = z * self.scale_factor
         noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
         z = self.q_sample(z, noise_level)
-        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
-        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        if self.out_size is not None:
+            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level
 
     def decode(self, z):
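
Note on the LowScaleEncoder change: when out_size is exactly twice the latent resolution, nearest-neighbor interpolation reproduces the old hard-coded repeat_interleave upsampling, while also allowing other target sizes and skipping upsampling entirely when out_size is None. A minimal sketch of that equivalence, assuming a hypothetical 4-channel 32x32 latent and out_size of (64, 64):

import torch
import torch.nn.functional as F

z = torch.randn(1, 4, 32, 32)  # hypothetical low-scale latent (B, C, H, W)

# old path: fixed 2x upsampling along both spatial dims
z_old = z.repeat_interleave(2, -2).repeat_interleave(2, -1)

# new path: upsample only when a target size is configured
out_size = (64, 64)  # hypothetical target; None would skip upsampling
z_new = F.interpolate(z, size=out_size, mode="nearest")

# for an exact 2x target, the two upsampling paths agree elementwise
assert torch.equal(z_old, z_new)
print(z_new.shape)  # torch.Size([1, 4, 64, 64])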