add logging for upscaler

parent c89452ef9a
commit 255d479088

5 changed files with 113 additions and 11 deletions
@@ -22,10 +22,11 @@ model:
     low_scale_config:
       target: ldm.modules.encoders.modules.LowScaleEncoder
       params:
+        scale_factor: 0.18215
         linear_start: 0.00085
         linear_end: 0.0120
         timesteps: 1000
-        max_noise_level: 250
+        max_noise_level: 100
         output_size: 64
         model_config:
           target: ldm.models.autoencoder.AutoencoderKL
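For orientation, the two keys touched here feed straight into `LowScaleEncoder.__init__` (see the `modules.py` hunks further down): `scale_factor` is a new constructor argument and `max_noise_level` caps the augmentation noise. Below is a minimal sketch of how a `target`/`params` block like this becomes an object; the loader is a simplified stand-in for `ldm.util.instantiate_from_config`, not the repo's code, and the autoencoder `model_config` is elided.

```python
import importlib

def instantiate_from_config(config: dict):
    # Stand-in for ldm.util.instantiate_from_config: resolve "module.path.Class"
    # and construct it with the params block as keyword arguments.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

low_scale_config = {
    "target": "ldm.modules.encoders.modules.LowScaleEncoder",
    "params": {
        "scale_factor": 0.18215,   # new key added in this commit
        "linear_start": 0.00085,
        "linear_end": 0.0120,
        "timesteps": 1000,
        "max_noise_level": 100,    # lowered from 250
        "output_size": 64,
        # "model_config": {...}    # AutoencoderKL config elided here
    },
}
# low_scale_model = instantiate_from_config(low_scale_config)  # requires the ldm package
```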
@@ -160,13 +161,14 @@ lightning:
         increase_log_steps: False
         log_first_step: False
         log_images_kwargs:
+          sample: False
           use_ema_scope: False
           inpaint: False
           plot_progressive_rows: False
           plot_diffusion_rows: False
           N: 4
-          unconditional_guidance_scale: 3.0
-          unconditional_guidance_label: [""]
+          #unconditional_guidance_scale: 3.0
+          #unconditional_guidance_label: [""]

   trainer:
     benchmark: True
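These `log_images_kwargs` are forwarded by the training-side image-logger callback into the model's `log_images()`. With `sample: False` added and the guidance keys commented out, the new `LatentUpscaleDiffusion.log_images` below falls back to its defaults (`unconditional_guidance_scale=1.`), so the classifier-free-guidance branch is skipped. A rough sketch of the forwarding pattern, with an illustrative callback rather than the repo's actual `main.ImageLogger`:

```python
import torch

# Illustrative only: shows how log_images_kwargs from the config end up as
# keyword arguments of the model's log_images() call.
log_images_kwargs = {
    "sample": False,               # newly added: skip the sampling branch
    "use_ema_scope": False,
    "inpaint": False,
    "plot_progressive_rows": False,
    "plot_diffusion_rows": False,
    "N": 4,
    # unconditional_guidance_scale / unconditional_guidance_label are commented
    # out above, so the method defaults apply (scale 1.0 -> no CFG pass).
}

def log_batch_images(pl_module, batch):
    with torch.no_grad():
        return pl_module.log_images(batch, **log_images_kwargs)
```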
@@ -20,6 +20,7 @@ model:
     low_scale_config:
       target: ldm.modules.encoders.modules.LowScaleEncoder
       params:
+        scale_factor: 0.18215
         linear_start: 0.00085
         linear_end: 0.0120
         timesteps: 1000
@@ -80,9 +80,12 @@ class DDIMSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

             else:
                 if conditioning.shape[0] != batch_size:
                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
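The sampler's sanity check previously assumed `conditioning[...]` was a tensor; with the upscaler the dict values are lists of tensors (see `all_conds` in the `ddpm.py` hunk below), so `.shape` is only available after unwrapping the list. A hedged illustration with made-up shapes:

```python
import torch

batch_size = 4
# Example conditioning dict in the new list-valued format; shapes are invented.
conditioning = {
    "c_concat": [torch.randn(4, 4, 64, 64)],    # e.g. a noised low-res latent
    "c_crossattn": [torch.randn(4, 77, 1024)],  # e.g. a text embedding
}

ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
    ctmp = ctmp[0]          # peel the list wrapper(s) until a tensor remains
cbs = ctmp.shape[0]
assert cbs == batch_size
# The old `conditioning[...].shape[0]` would have raised AttributeError here,
# since a plain list has no .shape.
```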
@@ -1278,10 +1278,10 @@ class LatentDiffusion(DDPM):
                 xc = self.cond_stage_model.decode(c)
                 log["conditioning"] = xc
             elif self.cond_stage_key in ["caption", "txt"]:
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
                 log["conditioning"] = xc
             elif self.cond_stage_key == 'class_label':
-                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
                 log['conditioning'] = xc
             elif isimage(xc):
                 log["conditioning"] = xc
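The only change here is caption rendering: the font size passed to `log_txt_as_img` now scales with the logged image height instead of relying on the function's fixed default. A quick check of the `height // 25` heuristic:

```python
# size = x.shape[2] // 25, i.e. roughly 4% of the image height in pixels.
for h in (128, 256, 512, 1024):
    print(f"{h}px image -> font size {h // 25}")   # 5, 10, 20, 40
```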
@@ -1463,16 +1463,105 @@ class LatentUpscaleDiffusion(LatentDiffusion):
             param.requires_grad = False

     @torch.no_grad()
-    def get_input(self, batch, k, cond_key=None, bs=None):
-        z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
-        x_low = batch[self.low_scale_key]
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                  force_c_encode=True, return_original_cond=True, bs=bs)
+        x_low = batch[self.low_scale_key][:bs]
         x_low = rearrange(x_low, 'b h w c -> b c h w')
         x_low = x_low.to(memory_format=torch.contiguous_format).float()
         zx, noise_level = self.low_scale_model(x_low)
         all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
         return z, all_conds

-    # TODO log it
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                          log_mode=True)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key == 'class_label':
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        return log
+

 class Layout2ImgDiffusion(LatentDiffusion):
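A hedged usage sketch for the new logging entry point. The batch keys depend on the model config (`first_stage_key`, `low_scale_key`, `cond_stage_key`), which this diff does not show; `"image"`, `"lr"` and `"txt"` below are assumptions, and `model` must be a fully configured `LatentUpscaleDiffusion` instance.

```python
import torch

def log_upscaler_images(model, batch_size=4):
    # Batch layout assumed: channels-last images in [-1, 1], as rearranged by get_input.
    batch = {
        "image": torch.rand(batch_size, 256, 256, 3) * 2 - 1,  # HR target (first_stage_key?)
        "lr": torch.rand(batch_size, 64, 64, 3) * 2 - 1,       # low-res input (low_scale_key?)
        "txt": ["a photo of a cat"] * batch_size,               # caption (cond_stage_key?)
    }
    with torch.no_grad():
        log = model.log_images(batch, N=batch_size, sample=True, ddim_steps=50,
                               plot_diffusion_rows=False, plot_progressive_rows=False,
                               unconditional_guidance_scale=3.0,
                               unconditional_guidance_label=[""])
    for k, v in log.items():
        shape = tuple(v.shape) if hasattr(v, "shape") else type(v).__name__
        print(k, shape)
    # Expected keys include: inputs, reconstruction, x_lr, x_lr_rec_@noise_levels...,
    # conditioning, samples, samples_cfg_scale_3.00
    return log
```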
@@ -208,13 +208,15 @@ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_t


 class LowScaleEncoder(nn.Module):
-    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
+    def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
+                 scale_factor=1.0):
         super().__init__()
         self.max_noise_level = max_noise_level
         self.model = instantiate_from_config(model_config)
         self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
                                                             linear_end=linear_end)
         self.out_size = output_size
+        self.scale_factor = scale_factor

     def register_schedule(self, beta_schedule="linear", timesteps=1000,
                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@@ -250,12 +252,17 @@ class LowScaleEncoder(nn.Module):

     def forward(self, x):
         z = self.model.encode(x).sample()
+        z = z * self.scale_factor
         noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
         z = self.q_sample(z, noise_level)
         #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")  # TODO: experiment with mode
         z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
         return z, noise_level

+    def decode(self, z):
+        z = z / self.scale_factor
+        return self.model.decode(z)
+

 if __name__ == "__main__":
     from ldm.util import count_params
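To see what the new `scale_factor` plumbing does end to end: `forward()` multiplies the encoded latent by `scale_factor` before noising and nearest-neighbour upsampling, and the new `decode()` divides it back out before handing the latent to the autoencoder. A toy round-trip, with a stub in place of the wrapped `AutoencoderKL` and the `q_sample` noising step skipped:

```python
import torch

# Not the repo's implementation: just the scaling/upsampling arithmetic of
# LowScaleEncoder.forward()/decode(), with encoding, noising, and the real
# autoencoder left out.
scale_factor = 0.18215
z = torch.randn(2, 4, 32, 32)          # pretend this came from model.encode(x).sample()

z_scaled = z * scale_factor            # forward(): bring latents to the diffusion scale
z_up = z_scaled.repeat_interleave(2, -2).repeat_interleave(2, -1)  # 32x32 -> 64x64 nearest-neighbour
assert z_up.shape == (2, 4, 64, 64)

z_back = z_up / scale_factor           # decode(): undo the scaling before model.decode(z)
assert torch.allclose(z_back[..., ::2, ::2], z)
```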