add logging for upscaler

This commit is contained in:
rromb 2022-06-13 10:43:41 +02:00
parent c89452ef9a
commit 255d479088
5 changed files with 113 additions and 11 deletions

View file

@ -22,10 +22,11 @@ model:
low_scale_config: low_scale_config:
target: ldm.modules.encoders.modules.LowScaleEncoder target: ldm.modules.encoders.modules.LowScaleEncoder
params: params:
scale_factor: 0.18215
linear_start: 0.00085 linear_start: 0.00085
linear_end: 0.0120 linear_end: 0.0120
timesteps: 1000 timesteps: 1000
max_noise_level: 250 max_noise_level: 100
output_size: 64 output_size: 64
model_config: model_config:
target: ldm.models.autoencoder.AutoencoderKL target: ldm.models.autoencoder.AutoencoderKL
@ -160,13 +161,14 @@ lightning:
increase_log_steps: False increase_log_steps: False
log_first_step: False log_first_step: False
log_images_kwargs: log_images_kwargs:
sample: False
use_ema_scope: False use_ema_scope: False
inpaint: False inpaint: False
plot_progressive_rows: False plot_progressive_rows: False
plot_diffusion_rows: False plot_diffusion_rows: False
N: 4 N: 4
unconditional_guidance_scale: 3.0 #unconditional_guidance_scale: 3.0
unconditional_guidance_label: [""] #unconditional_guidance_label: [""]
trainer: trainer:
benchmark: True benchmark: True

View file

@ -20,6 +20,7 @@ model:
low_scale_config: low_scale_config:
target: ldm.modules.encoders.modules.LowScaleEncoder target: ldm.modules.encoders.modules.LowScaleEncoder
params: params:
scale_factor: 0.18215
linear_start: 0.00085 linear_start: 0.00085
linear_end: 0.0120 linear_end: 0.0120
timesteps: 1000 timesteps: 1000

View file

@ -80,9 +80,12 @@ class DDIMSampler(object):
): ):
if conditioning is not None: if conditioning is not None:
if isinstance(conditioning, dict): if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0] ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size: if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else: else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

View file

@ -1278,10 +1278,10 @@ class LatentDiffusion(DDPM):
xc = self.cond_stage_model.decode(c) xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]: elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key]) xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
log["conditioning"] = xc log["conditioning"] = xc
elif self.cond_stage_key == 'class_label': elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
log['conditioning'] = xc log['conditioning'] = xc
elif isimage(xc): elif isimage(xc):
log["conditioning"] = xc log["conditioning"] = xc
@ -1463,16 +1463,105 @@ class LatentUpscaleDiffusion(LatentDiffusion):
param.requires_grad = False param.requires_grad = False
@torch.no_grad() @torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None): def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs) if not log_mode:
x_low = batch[self.low_scale_key] z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w') x_low = rearrange(x_low, 'b h w c -> b c h w')
x_low = x_low.to(memory_format=torch.contiguous_format).float() x_low = x_low.to(memory_format=torch.contiguous_format).float()
zx, noise_level = self.low_scale_model(x_low) zx, noise_level = self.low_scale_model(x_low)
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
if log_mode:
# TODO: maybe disable if too expensive
x_low_rec = self.low_scale_model.decode(zx)
return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
return z, all_conds return z, all_conds
# TODO log it @torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
log["x_lr"] = x_low
log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
return log
class Layout2ImgDiffusion(LatentDiffusion): class Layout2ImgDiffusion(LatentDiffusion):

View file

@ -208,13 +208,15 @@ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_t
class LowScaleEncoder(nn.Module): class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64): def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
scale_factor=1.0):
super().__init__() super().__init__()
self.max_noise_level = max_noise_level self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config) self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start, self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
linear_end=linear_end) linear_end=linear_end)
self.out_size = output_size self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(self, beta_schedule="linear", timesteps=1000, def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
@ -250,12 +252,17 @@ class LowScaleEncoder(nn.Module):
def forward(self, x): def forward(self, x):
z = self.model.encode(x).sample() z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
z = self.q_sample(z, noise_level) z = self.q_sample(z, noise_level)
#z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode #z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)
if __name__ == "__main__": if __name__ == "__main__":
from ldm.util import count_params from ldm.util import count_params