diff --git a/scripts/demo/inpainting.py b/scripts/demo/inpainting.py
index 9ee31f8..d233d7f 100644
--- a/scripts/demo/inpainting.py
+++ b/scripts/demo/inpainting.py
@@ -114,20 +114,22 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-if __name__ == "__main__":
+def run(
+    config="/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml",
+    #ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-28T07-44-05_v1-finetune-for-inpainting-laion-aesthetic-larger-masks/checkpoints/last.ckpt",
+    ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-08-01T08-52-14_v1-finetune-for-inpainting-laion-aesthetic-larger-masks-and-ucfg/checkpoints/last.ckpt",
+):
     st.title("Stable Inpainting")
 
     state = init()
     if not "model" in state:
-        config = "/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml"
-        ckpt = "/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-24T16-01-27_v1-finetune-for-inpainting-laion-iaesthe/checkpoints/last.ckpt"
         config = OmegaConf.load(config)
         model = load_model_from_config(config, ckpt)
         state["model"] = model
 
     uploaded_file = st.file_uploader("Upload image to inpaint")
     if uploaded_file is not None:
-        image = Image.open(io.BytesIO(uploaded_file.getvalue()))
+        image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
         width, height = image.size
         smaller = min(width, height)
         crop = (
@@ -206,3 +208,8 @@ if __name__ == "__main__":
         )
         st.text("Samples")
         st.image(samples[0])
+
+
+if __name__ == "__main__":
+    import fire
+    fire.Fire(run)
diff --git a/scripts/inpaint_sd.py b/scripts/inpaint_sd.py
new file mode 100644
index 0000000..568a0bc
--- /dev/null
+++ b/scripts/inpaint_sd.py
@@ -0,0 +1,190 @@
+import argparse, os, sys, glob
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+import torch
+from main import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+def make_batch_ldm(image, mask, device):
+    image = np.array(Image.open(image).convert("RGB"))
+    image = image.astype(np.float32)/255.0
+    image = image[None].transpose(0,3,1,2)
+    image = torch.from_numpy(image)
+
+    mask = np.array(Image.open(mask).convert("L"))
+    mask = mask.astype(np.float32)/255.0
+    mask = mask[None,None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+    mask = torch.from_numpy(mask)
+
+    masked_image = (1-mask)*image
+
+    batch = {"image": image, "mask": mask, "masked_image": masked_image}
+    for k in batch:
+        batch[k] = batch[k].to(device=device)
+        batch[k] = batch[k]*2.0-1.0
+    return batch
+
+
+def make_batch_sd(
+        image,
+        mask,
+        txt,
+        device):
+    # image: bchw float tensor in [-1, 1]
+    image = np.array(Image.open(image).convert("RGB"))
+    image = image[None].transpose(0,3,1,2)
+    image = torch.from_numpy(image).to(dtype=torch.float32)/127.5-1.0
+
+    # mask: bchw float tensor in {0, 1}; 1 marks the region to inpaint
+    mask = np.array(Image.open(mask).convert("L"))
+    mask = mask.astype(np.float32)/255.0
+    mask = mask[None,None]
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+    mask = torch.from_numpy(mask)
+
+    masked_image = image * (mask < 0.5)
+
+    batch = {
+        "jpg": image.to(device=device),
+        "txt": [txt],
+        "mask": mask.to(device=device),
+        "masked_image": masked_image.to(device=device),
+    }
+    return batch
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--indir",
+        type=str,
+        nargs="?",
+        help="dir containing image-mask pairs (`example.png` and `example_mask000.png`)",
+    )
+    parser.add_argument(
"--outdir", + type=str, + nargs="?", + help="dir to write results to", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--eta", + type=float, + default=0.0, + help="eta of ddim", + ) + parser.add_argument( + "--scale", + type=float, + default=6.0, + help="scale of unconditional guidance", + ) + parser.add_argument( + "--worldsize", + type=int, + default=1, + help="scale of unconditional guidance", + ) + parser.add_argument( + "--rank", + type=int, + default=0, + help="scale of unconditional guidance", + ) + parser.add_argument( + "--ckpt", + type=str, + default="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-08-01T08-52-14_v1-finetune-for-inpainting-laion-aesthetic-larger-masks-and-ucfg/checkpoints/last.ckpt", + help="scale of unconditional guidance", + ) + opt = parser.parse_args() + + assert opt.rank < opt.worldsize + + mstr = "mask000.png" + masks = sorted(glob.glob(os.path.join(opt.indir, f"*_{mstr}"))) + images = [x.replace(f"_{mstr}", ".png") for x in masks] + print(f"Found {len(masks)} inputs.") + + #config = "models/ldm/inpainting_big/config.yaml" + config="/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml" + config = OmegaConf.load(config) + model = instantiate_from_config(config.model) + #ckpt="models/ldm/inpainting_big/last.ckpt" + ckpt=opt.ckpt + model.load_state_dict(torch.load(ckpt)["state_dict"], + strict=False) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + indices = [i for i in range(len(images)) if i % opt.worldsize == opt.rank] + images = [images[i] for i in indices] + masks = [masks[i] for i in indices] + + os.makedirs(opt.outdir, exist_ok=True) + with torch.no_grad(): + with model.ema_scope(): + for image, mask in tqdm(zip(images, masks), total=len(images)): + outpath = os.path.join(opt.outdir, os.path.split(image)[1]) + #batch = make_batch_ldm(image, mask, device=device) + + ##### unroll + batch = make_batch_sd(image, mask, txt="photograph of a beautiful empty scene, highest quality settings", + device=device) + + c = model.cond_stage_model.encode(batch["txt"]) + + c_cat = list() + for ck in model.concat_keys: + cc = batch[ck].float() + if ck != model.masked_image_key: + bchw = (1, 4, 64, 64) + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = model.get_first_stage_encoding(model.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + + # cond + cond={"c_concat": [c_cat], "c_crossattn": [c]} + + # uncond cond + uc_cross = model.get_unconditional_conditioning(1, "") + uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} + + shape = (model.channels, model.image_size, model.image_size) + samples_cfg, intermediates = sampler.sample( + opt.steps, + 1, + shape, + cond, + verbose=False, + eta=opt.eta, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc_full, + ) + x_samples_ddim = model.decode_first_stage(samples_cfg) + + image = torch.clamp((batch["jpg"]+1.0)/2.0, + min=0.0, max=1.0) + mask = torch.clamp((batch["mask"]+1.0)/2.0, + min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, + min=0.0, max=1.0) + + inpainted = (1-mask)*image+mask*predicted_image + inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + Image.fromarray(inpainted.astype(np.uint8)).save(outpath)