From 8ce2c914d306f5a3cae5eea4143b055ebdcca253 Mon Sep 17 00:00:00 2001
From: Patrick Esser
Date: Tue, 26 Jul 2022 08:24:02 +0000
Subject: [PATCH] st inpainting

---
 scripts/demo/inpainting.py | 208 +++++++++++++++++++++++++++++++++++++
 1 file changed, 208 insertions(+)
 create mode 100644 scripts/demo/inpainting.py

diff --git a/scripts/demo/inpainting.py b/scripts/demo/inpainting.py
new file mode 100644
index 0000000..9ee31f8
--- /dev/null
+++ b/scripts/demo/inpainting.py
@@ -0,0 +1,208 @@
+import streamlit as st
+import torch
+import cv2
+import numpy as np
+from ldm.util import instantiate_from_config
+from omegaconf import OmegaConf
+from PIL import Image
+import io
+from streamlit_drawable_canvas import st_canvas
+
+
+torch.set_grad_enabled(False)
+
+
+def sample(
+    model,
+    prompt,
+    n_runs=3,
+    n_samples=2,
+    H=512,
+    W=512,
+    scale=10.0,
+    ddim_steps=50,
+    callback=None,
+    image=None,
+    mask=None,
+):
+    # NOTE: n_runs, n_samples, H, W and callback are accepted but not used yet;
+    # a single 512x512 sample is drawn per call.
+    batch = np2batch(image=image, mask=mask, txt=prompt)
+
+    self = model
+    unconditional_guidance_scale = scale
+    unconditional_guidance_label = [""]
+    use_ddim = True
+    ddim_eta = 0
+    N = 1
+    ema_scope = self.ema_scope
+
+    z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+    c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+
+    if unconditional_guidance_scale > 1.0:
+        uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+        # keep the masked-image conditioning for the unconditional branch
+        uc_cat = c_cat
+        uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+        with ema_scope("Sampling with classifier-free guidance"):
+            samples_cfg, _ = self.sample_log(
+                cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                batch_size=N,
+                ddim=use_ddim,
+                ddim_steps=ddim_steps,
+                eta=ddim_eta,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=uc_full,
+            )
+        samples = self.decode_first_stage(samples_cfg)
+    else:
+        raise ValueError("this demo requires a guidance scale > 1.0")
+
+    samples = torch2np(samples)
+    return samples
+
+
+def np2batch(image, mask, txt):
+    # image: HxWx3 uint8 in [0, 255] -> float tensor in [-1, 1]
+    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+    # mask: HxWxC array where nonzero marks the region to inpaint
+    mask = mask.astype(np.float32)
+    mask[mask < 0.5] = 0
+    mask[mask >= 0.5] = 1
+    mask = torch.from_numpy(mask)[:, :, :1]
+    masked_image = image * (mask < 0.5)
+
+    batch = {
+        "jpg": image[None],
+        "txt": [txt],
+        "mask": mask[None],
+        "masked_image": masked_image[None],
+    }
+    return batch
+
+
+def torch2np(x):
+    x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
+    x = x.permute(0, 2, 3, 1).detach().cpu().numpy()
+    return x
+
+
+@st.cache(allow_output_mutation=True)
+def init():
+    state = dict()
+    return state
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+    print(f"Loading model from {ckpt}")
+    pl_sd = torch.load(ckpt, map_location="cpu")
+    global_step = pl_sd.get("global_step", "?")
+    sd = pl_sd["state_dict"]
+    model = instantiate_from_config(config.model)
+    m, u = model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+
+    model.cuda()
+    model.eval()
+    print(f"Loaded global step {global_step}")
+    return model
+
+
+if __name__ == "__main__":
+    st.title("Stable Inpainting")
+    state = init()
+
+    if "model" not in state:
+        config = "/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml"
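+        # NOTE: the /fsx paths here are specific to the training cluster; point
+        # config and ckpt at local copies to run the demo elsewhere.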
"/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-24T16-01-27_v1-finetune-for-inpainting-laion-iaesthe/checkpoints/last.ckpt" + config = OmegaConf.load(config) + model = load_model_from_config(config, ckpt) + state["model"] = model + + uploaded_file = st.file_uploader("Upload image to inpaint") + if uploaded_file is not None: + image = Image.open(io.BytesIO(uploaded_file.getvalue())) + width, height = image.size + smaller = min(width, height) + crop = ( + (width-smaller)//2, + (height-smaller)//2, + (width-smaller)//2+smaller, + (height-smaller)//2+smaller, + ) + image = image.crop(crop) + image = image.resize((512, 512)) + #st.write("Uploaded Image") + #st.image(image) + + st.write("Draw a mask (and send it to streamlit, button lower left)") + stroke_width = int(st.number_input("Stroke Width", value=50)) + canvas_result = st_canvas( + fill_color="rgba(255, 255, 255)", # Fixed fill color with some opacity + stroke_width=stroke_width, + stroke_color="rgb(0, 0, 0)", + background_color="rgb(0, 0, 0)", + background_image=image if image is not None else Image.fromarray(255*np.ones((512,512,3), + dtype=np.uint8)), + update_streamlit=False, + height=image.size[1] if image is not None else 512, + width=image.size[0] if image is not None else 512, + drawing_mode="freedraw", + point_display_radius=0, + key="canvas", + ) + if canvas_result: + mask = canvas_result.image_data + mask = np.array(mask)[:,:,[3,3,3]] + mask = mask > 127 + + # visualize + bdry = cv2.dilate(mask.astype(np.uint8), np.ones((3,3), dtype=np.uint8)) + bdry = (bdry > 0) & ~mask + + masked_image = np.array(image)*(1-mask) + mask*0.3*np.array(image) + masked_image[:,:,0][bdry[:,:,0]] = 255 + masked_image[:,:,1][bdry[:,:,1]] = 0 + masked_image[:,:,2][bdry[:,:,2]] = 0 + st.write("Masked Image") + st.image(Image.fromarray(masked_image.astype(np.uint8))) + + prompt = st.text_input("Prompt") + scale = float(st.number_input("Guidance", value=10.0)) + t_total = int(st.number_input("Diffusion steps", value=50)) + + if st.button("Sample"): + st.text("Sampling") + batch_progress = st.progress(0) + batch_total = 3 + t_progress = st.progress(0) + result = st.empty() + #canvas = make_canvas(2, 3) + def callback(x, batch, t): + #result.text(f"{batch}, {t}") + batch_progress.progress(min(1.0, (batch+1)/batch_total)) + t_progress.progress(min(1.0, (t+1)/t_total)) + update_canvas(canvas, x, batch) + result.image(canvas) + + samples = sample( + state["model"], + prompt, + n_runs=3, + n_samples=2, + H=512, + W=512, + scale=scale, + ddim_steps=t_total, + callback=callback, + image=np.array(image), + mask=np.array(mask), + ) + st.text("Samples") + st.image(samples[0])