paint it in

This commit is contained in:
Robin Rombach 2022-07-24 13:24:30 +02:00
parent 76e2f4b739
commit 57c3a76346
1 changed files with 162 additions and 1 deletions

View File

@ -1260,7 +1260,6 @@ class LatentDiffusion(DDPM):
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
@ -1582,6 +1581,168 @@ class LatentUpscaleDiffusion(LatentDiffusion):
return log
class LatentInpaintDiffusion(LatentDiffusion):
"""
can either run as pure inpainting model (only concat mode) or with mixed conditionings,
e.g. mask as concat and text via cross-attn.
To disable finetuning mode, set finetune_keys to None
"""
def __init__(self,
finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", "model_ema.diffusion_modelinput_blocks00weight"),
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels
*args, **kwargs
):
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", list())
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.finetune_keys = finetune_keys
self.concat_keys = concat_keys
self.keep_dims = keep_finetune_dims
if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
if exists(ckpt_path):
self.init_from_ckpt(ckpt_path, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
# make it explicit, finetune by including extra input channels
if exists(self.finetune_keys) and k in self.finetune_keys:
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
new_entry = torch.zeros_like(param) # zero init
assert exists(new_entry), 'did not find matching parameter to modify'
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpaiting'
z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
force_c_encode=True, return_original_cond=True, bs=bs)
assert exists(self.concat_keys)
c_cat = list()
for ck in self.concat_keys:
cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
bchw = z.shape
if ck != self.masked_image_key:
cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
else:
cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N, ddim=use_ddim,
ddim_steps=ddim_steps, eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
log["masked_image"] = rearrange(batch["masked_image"],
'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
return log
class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):