playing with upscaler
This commit is contained in:
parent
f6016af80a
commit
71634a21c7
2 changed files with 10 additions and 9 deletions
|
@ -1473,9 +1473,10 @@ class LatentUpscaleDiffusion(LatentDiffusion):
|
|||
x_low = x_low.to(memory_format=torch.contiguous_format).float()
|
||||
zx, noise_level = self.low_scale_model(x_low)
|
||||
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
|
||||
#import pudb; pu.db
|
||||
if log_mode:
|
||||
# TODO: maybe disable if too expensive
|
||||
interpretability = True
|
||||
interpretability = False
|
||||
if interpretability:
|
||||
zx = zx[:, :, ::2, ::2]
|
||||
x_low_rec = self.low_scale_model.decode(zx)
|
||||
|
@ -1553,13 +1554,13 @@ class LatentUpscaleDiffusion(LatentDiffusion):
|
|||
if k == "c_crossattn":
|
||||
assert isinstance(c[k], list) and len(c[k]) == 1
|
||||
uc[k] = [uc_tmp]
|
||||
elif k == "c_adm":
|
||||
elif k == "c_adm": # todo: only run with text-based guidance?
|
||||
assert isinstance(c[k], torch.Tensor)
|
||||
uc[k] = torch.ones_like(c[k]) * (self.low_scale_model.max_noise_level-1)
|
||||
uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
|
||||
elif isinstance(c[k], list):
|
||||
uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
|
||||
uc[k] = [c[k][i] for i in range(len(c[k]))]
|
||||
else:
|
||||
uc[k] = torch.zeros_like(c[k])
|
||||
uc[k] = c[k]
|
||||
|
||||
with ema_scope("Sampling with classifier-free guidance"):
|
||||
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
||||
|
@ -1628,8 +1629,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
|||
new_entry[:, :self.keep_dims, ...] = sd[k]
|
||||
sd[k] = new_entry
|
||||
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||
sd, strict=False)
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
|
|
|
@ -255,8 +255,9 @@ class LowScaleEncoder(nn.Module):
|
|||
z = z * self.scale_factor
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
z = self.q_sample(z, noise_level)
|
||||
#z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
|
||||
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
if self.out_size is not None:
|
||||
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
|
||||
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
|
|
Loading…
Reference in a new issue