add logging for upscaler
This commit is contained in:
parent
c89452ef9a
commit
255d479088
5 changed files with 113 additions and 11 deletions
|
@ -22,10 +22,11 @@ model:
|
|||
low_scale_config:
|
||||
target: ldm.modules.encoders.modules.LowScaleEncoder
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
timesteps: 1000
|
||||
max_noise_level: 250
|
||||
max_noise_level: 100
|
||||
output_size: 64
|
||||
model_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
|
@ -160,13 +161,14 @@ lightning:
|
|||
increase_log_steps: False
|
||||
log_first_step: False
|
||||
log_images_kwargs:
|
||||
sample: False
|
||||
use_ema_scope: False
|
||||
inpaint: False
|
||||
plot_progressive_rows: False
|
||||
plot_diffusion_rows: False
|
||||
N: 4
|
||||
unconditional_guidance_scale: 3.0
|
||||
unconditional_guidance_label: [""]
|
||||
#unconditional_guidance_scale: 3.0
|
||||
#unconditional_guidance_label: [""]
|
||||
|
||||
trainer:
|
||||
benchmark: True
|
||||
|
|
|
@ -20,6 +20,7 @@ model:
|
|||
low_scale_config:
|
||||
target: ldm.modules.encoders.modules.LowScaleEncoder
|
||||
params:
|
||||
scale_factor: 0.18215
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
timesteps: 1000
|
||||
|
|
|
@ -80,9 +80,12 @@ class DDIMSampler(object):
|
|||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
|
|
@ -1278,10 +1278,10 @@ class LatentDiffusion(DDPM):
|
|||
xc = self.cond_stage_model.decode(c)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key in ["caption", "txt"]:
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key])
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key == 'class_label':
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
|
||||
log['conditioning'] = xc
|
||||
elif isimage(xc):
|
||||
log["conditioning"] = xc
|
||||
|
@ -1463,16 +1463,105 @@ class LatentUpscaleDiffusion(LatentDiffusion):
|
|||
param.requires_grad = False
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, cond_key=None, bs=None):
|
||||
z, c, x = super().get_input(batch, k, return_x=True, force_c_encode=True, bs=bs)
|
||||
x_low = batch[self.low_scale_key]
|
||||
def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
|
||||
if not log_mode:
|
||||
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
|
||||
else:
|
||||
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
|
||||
force_c_encode=True, return_original_cond=True, bs=bs)
|
||||
x_low = batch[self.low_scale_key][:bs]
|
||||
x_low = rearrange(x_low, 'b h w c -> b c h w')
|
||||
x_low = x_low.to(memory_format=torch.contiguous_format).float()
|
||||
zx, noise_level = self.low_scale_model(x_low)
|
||||
all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
|
||||
if log_mode:
|
||||
# TODO: maybe disable if too expensive
|
||||
x_low_rec = self.low_scale_model.decode(zx)
|
||||
return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
|
||||
return z, all_conds
|
||||
|
||||
# TODO log it
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
|
||||
plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
|
||||
unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
|
||||
**kwargs):
|
||||
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = dict()
|
||||
z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
|
||||
log_mode=True)
|
||||
N = min(x.shape[0], N)
|
||||
n_row = min(x.shape[0], n_row)
|
||||
log["inputs"] = x
|
||||
log["reconstruction"] = xrec
|
||||
log["x_lr"] = x_low
|
||||
log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
|
||||
if self.model.conditioning_key is not None:
|
||||
if hasattr(self.cond_stage_model, "decode"):
|
||||
xc = self.cond_stage_model.decode(c)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key in ["caption", "txt"]:
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key == 'class_label':
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25)
|
||||
log['conditioning'] = xc
|
||||
elif isimage(xc):
|
||||
log["conditioning"] = xc
|
||||
if ismap(xc):
|
||||
log["original_conditioning"] = self.to_rgb(xc)
|
||||
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
||||
t = t.to(self.device).long()
|
||||
noise = torch.randn_like(z_start)
|
||||
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
||||
diffusion_row.append(self.decode_first_stage(z_noisy))
|
||||
|
||||
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
||||
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
||||
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
||||
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
||||
log["diffusion_row"] = diffusion_grid
|
||||
|
||||
if sample:
|
||||
# get denoise row
|
||||
with ema_scope("Sampling"):
|
||||
samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta)
|
||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log["samples"] = x_samples
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log["denoise_row"] = denoise_grid
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
||||
with ema_scope("Sampling with classifier-free guidance"):
|
||||
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
|
||||
ddim_steps=ddim_steps, eta=ddim_eta,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
)
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||
|
||||
if plot_progressive_rows:
|
||||
with ema_scope("Plotting Progressives"):
|
||||
img, progressives = self.progressive_denoising(c,
|
||||
shape=(self.channels, self.image_size, self.image_size),
|
||||
batch_size=N)
|
||||
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
|
||||
log["progressive_row"] = prog_row
|
||||
|
||||
return log
|
||||
|
||||
|
||||
class Layout2ImgDiffusion(LatentDiffusion):
|
||||
|
|
|
@ -208,13 +208,15 @@ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_t
|
|||
|
||||
|
||||
class LowScaleEncoder(nn.Module):
|
||||
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64):
|
||||
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
|
||||
scale_factor=1.0):
|
||||
super().__init__()
|
||||
self.max_noise_level = max_noise_level
|
||||
self.model = instantiate_from_config(model_config)
|
||||
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
|
||||
linear_end=linear_end)
|
||||
self.out_size = output_size
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
|
@ -250,12 +252,17 @@ class LowScaleEncoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
z = self.model.encode(x).sample()
|
||||
z = z * self.scale_factor
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
z = self.q_sample(z, noise_level)
|
||||
#z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
|
||||
z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
z = z / self.scale_factor
|
||||
return self.model.decode(z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
|
|
Loading…
Reference in a new issue