playing with upscaler
parent f6016af80a
commit 71634a21c7
2 changed files with 10 additions and 9 deletions
@@ -1473,9 +1473,10 @@ class LatentUpscaleDiffusion(LatentDiffusion):
         x_low = x_low.to(memory_format=torch.contiguous_format).float()
         zx, noise_level = self.low_scale_model(x_low)
         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        #import pudb; pu.db
         if log_mode:
             # TODO: maybe disable if too expensive
-            interpretability = True
+            interpretability = False
             if interpretability:
                 zx = zx[:, :, ::2, ::2]
             x_low_rec = self.low_scale_model.decode(zx)
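Note on the interpretability toggle above: the stride-2 slice zx[:, :, ::2, ::2] keeps every second row and column of the conditioning latent, halving its spatial resolution before it is decoded for logging. A minimal sketch of that slicing (shapes are made up, not from this repo):

    import torch

    zx = torch.randn(1, 4, 128, 128)  # hypothetical latent: batch, channels, H, W
    zx_half = zx[:, :, ::2, ::2]      # keep every second row/column
    print(zx_half.shape)              # torch.Size([1, 4, 64, 64])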
@@ -1553,13 +1554,13 @@ class LatentUpscaleDiffusion(LatentDiffusion):
                 if k == "c_crossattn":
                     assert isinstance(c[k], list) and len(c[k]) == 1
                     uc[k] = [uc_tmp]
-                elif k == "c_adm":
+                elif k == "c_adm":  # todo: only run with text-based guidance?
                     assert isinstance(c[k], torch.Tensor)
-                    uc[k] = torch.ones_like(c[k]) * (self.low_scale_model.max_noise_level-1)
+                    uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
                 elif isinstance(c[k], list):
-                    uc[k] = [torch.zeros_like(c[k][i]) for i in range(len(c[k]))]
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
                 else:
-                    uc[k] = torch.zeros_like(c[k])
+                    uc[k] = c[k]
 
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
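The rewritten uc dict changes what the unconditional branch of classifier-free guidance sees: instead of zero tensors, every non-text key now reuses the conditional value (with c_adm pinned to the maximum noise level), so the two U-Net passes differ only in the text embedding. A hedged sketch of the guidance step this dict feeds into; eps_uncond and eps_cond are hypothetical stand-ins for the two model passes:

    import torch

    def cfg_combine(eps_uncond, eps_cond, scale):
        # With uc[k] == c[k] for the non-text keys, both passes condition on
        # the same low-res latent, so the guidance direction reflects only
        # the difference introduced by the text conditioning.
        return eps_uncond + scale * (eps_cond - eps_uncond)

    guided = cfg_combine(torch.randn(1, 4, 64, 64), torch.randn(1, 4, 64, 64), scale=7.5)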
@@ -1628,8 +1629,7 @@ class LatentInpaintDiffusion(LatentDiffusion):
                 new_entry[:, :self.keep_dims, ...] = sd[k]
                 sd[k] = new_entry
 
-        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
-            sd, strict=False)
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
         if len(missing) > 0:
             print(f"Missing Keys: {missing}")
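The change here is purely cosmetic (the wrapped call is joined onto one line); both sides rely on load_state_dict(..., strict=False) returning the lists of missing and unexpected keys. A self-contained sketch of that pattern on a toy module (not from this repo):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    sd = {"weight": torch.zeros(2, 4)}  # deliberately omits "bias"
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(missing)     # ['bias']
    print(unexpected)  # []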
@@ -255,8 +255,9 @@ class LowScaleEncoder(nn.Module):
         z = z * self.scale_factor
         noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
         z = self.q_sample(z, noise_level)
-        #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
-        z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
+        if self.out_size is not None:
+            z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
+        # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level
 
     def decode(self, z):
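This hunk swaps the hard-coded 2x repeat_interleave upsampling for a size-targeted nearest-neighbor interpolate, gated on self.out_size. For an exact 2x target the two paths produce identical tensors, which a quick check confirms (toy shapes, not from this repo):

    import torch
    import torch.nn.functional as F

    z = torch.randn(1, 4, 32, 32)
    a = F.interpolate(z, size=(64, 64), mode="nearest")
    b = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
    print(torch.equal(a, b))  # True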