playing with upscaler

This commit is contained in:
Robin Rombach 2022-07-28 00:08:46 +02:00
parent f6016af80a
commit 71634a21c7
2 changed files with 10 additions and 9 deletions

View File

@ -1473,9 +1473,10 @@ class LatentUpscaleDiffusion(LatentDiffusion):
x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low)
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
#import pudb; pu.db
if log_mode:
# TODO: maybe disable if too expensive
interpretability = True
interpretability = False
if interpretability:
zx = zx[:, :, ::2, ::2]
x_low_rec = self.low_scale_model.decode(zx)
@ -1553,13 +1554,13 @@ class LatentUpscaleDiffusion(LatentDiffusion):
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
elif k == "c_adm":
elif k == "c_adm": # todo: only run with text-based guidance?
assert isinstance(c[k], torch.Tensor)
uc[k] = torch.ones_like(c[k]) * (self.low_scale_model.max_noise_level-1)
uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
elif isinstance(c[k], list):
uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
uc[k] = [c[k][i] for i in range(len(c[k]))]
else:
uc[k] = torch.zeros_like(c[k])
uc[k] = c[k]
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
@ -1628,8 +1629,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")

View File

@ -255,8 +255,9 @@ class LowScaleEncoder(nn.Module):
z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
z = self.q_sample(z, noise_level)
#z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):