"""Streamlit demo: text-guided inpainting with a latent diffusion model.

The user uploads an image, paints a mask on a drawable canvas, enters a
prompt, and the model fills in the masked region.
"""
import io

import cv2
import numpy as np
import streamlit as st
import torch
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# Inference only: autograd is switched off for the whole process.
torch.set_grad_enabled(False)

CONFIG_PATH = "/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml"
CKPT_PATH = "/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-24T16-01-27_v1-finetune-for-inpainting-laion-iaesthe/checkpoints/last.ckpt"


def sample(
    model,
    prompt,
    n_runs=3,
    n_samples=2,
    H=512,
    W=512,
    scale=10.0,
    ddim_steps=50,
    callback=None,
    image=None,
    mask=None,
):
    """Run one classifier-free-guided DDIM inpainting pass.

    Args:
        model: Object providing ``ema_scope``, ``get_input``,
            ``first_stage_key``, ``get_unconditional_conditioning``,
            ``sample_log`` and ``decode_first_stage``.
        prompt: Text prompt placed in the batch under ``"txt"``.
        n_runs, n_samples, H, W, callback: Accepted for interface
            compatibility; this function does not use them.
        scale: Unconditional guidance scale; must be greater than 1.0.
        ddim_steps: Number of DDIM steps forwarded to ``sample_log``.
        image: Image array handed to ``np2batch``.
        mask: Mask array handed to ``np2batch``.

    Returns:
        The decoded samples converted by ``torch2np`` (uint8 NumPy array).

    Raises:
        ValueError: If ``scale`` is not greater than 1.0; only the guided
            sampling path is implemented.
    """
    # Fail fast, before any encoding work is done. Written as "not >" so
    # that NaN is rejected exactly as the original if/else did.
    if not scale > 1.0:
        raise ValueError(
            f"scale must be > 1.0 for classifier-free guidance, got {scale!r}"
        )

    batch = np2batch(image=image, mask=mask, txt=prompt)
    batch_size = 1

    # Only the conditioning dict is needed from get_input's outputs.
    _, cond, _, _, _ = model.get_input(
        batch,
        model.first_stage_key,
        bs=batch_size,
        return_first_stage_outputs=True,
    )
    c_cat = cond["c_concat"][0]
    c_cross = cond["c_crossattn"][0]

    # The unconditional branch uses an empty prompt but keeps the same
    # concatenated conditioning as the conditional branch.
    uc_cross = model.get_unconditional_conditioning(batch_size, [""])
    uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

    with model.ema_scope("Sampling with classifier-free guidance"):
        samples_cfg, _ = model.sample_log(
            cond={"c_concat": [c_cat], "c_crossattn": [c_cross]},
            batch_size=batch_size,
            ddim=True,
            ddim_steps=ddim_steps,
            eta=0,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc_full,
        )
        samples = model.decode_first_stage(samples_cfg)

    return torch2np(samples)


def np2batch(image, mask, txt):
    """Build the model input batch from NumPy image/mask arrays and a prompt.

    Args:
        image: Image array; rescaled here as ``x / 127.5 - 1.0``
            (presumably HWC with values in 0..255 -- confirm with caller).
        mask: Mask array with at least three dimensions; only the first
            channel is kept. NOTE: it is binarised in place.
        txt: Prompt string.

    Returns:
        Dict with keys ``"jpg"``, ``"txt"``, ``"mask"`` and
        ``"masked_image"``; tensor entries carry a leading batch axis of 1.
    """
    print("###")
    print(image.shape)
    print(mask.shape)
    print("###")

    # image hwc in -1 1
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # Binarise around 0.5 (a value of exactly 0.5 is left untouched).
    mask[mask < 0.5] = 0
    mask[mask > 0.5] = 1
    mask = torch.from_numpy(mask)[:, :, :1]

    # Zero out the pixels that the model is asked to repaint.
    masked_image = image * (mask < 0.5)

    return {
        "jpg": image[None],
        "txt": [txt],
        "mask": mask[None],
        "masked_image": masked_image[None],
    }


def torch2np(x):
    """Map a tensor from [-1, 1] to uint8 0..255 and move axis 1 to the end.

    The input is indexed as four-dimensional (``permute(0, 2, 3, 1)``);
    presumably (N, C, H, W) -> (N, H, W, C). Returns a NumPy array on CPU.
    """
    x = ((x + 1.0) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
    return x.permute(0, 2, 3, 1).detach().cpu().numpy()


@st.cache(allow_output_mutation=True)
def init():
    """Return the mutable state dict that Streamlit caches across reruns."""
    return dict()


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model``, load checkpoint weights, move to CUDA.

    Args:
        config: Configuration object exposing a ``model`` entry accepted by
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint containing a ``"state_dict"`` entry.
        verbose: If true, print missing and unexpected state-dict keys.

    Returns:
        The model on the GPU, in eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    global_step = pl_sd.get("global_step", "?")

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(pl_sd["state_dict"], strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    model.cuda()
    model.eval()
    print(f"Loaded global step {global_step}")
    return model


def _load_square_image(uploaded_file, size=512):
    """Decode an upload, center-crop it to a square and resize to ``size``."""
    # Force 3 channels: the mask arithmetic below broadcasts against an
    # (H, W, 3) mask, which fails for RGBA or grayscale uploads.
    image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
    width, height = image.size
    smaller = min(width, height)
    left = (width - smaller) // 2
    top = (height - smaller) // 2
    image = image.crop((left, top, left + smaller, top + smaller))
    return image.resize((size, size))


def main():
    """Streamlit page: upload an image, draw a mask, prompt, and sample."""
    st.title("Stable Inpainting")
    state = init()

    # Load the model once; later reruns reuse the cached state dict.
    if "model" not in state:
        config = OmegaConf.load(CONFIG_PATH)
        state["model"] = load_model_from_config(config, CKPT_PATH)

    uploaded_file = st.file_uploader("Upload image to inpaint")
    if uploaded_file is None:
        return

    image = _load_square_image(uploaded_file)

    st.write("Draw a mask (and send it to streamlit, button lower left)")
    stroke_width = int(st.number_input("Stroke Width", value=50))
    canvas_result = st_canvas(
        fill_color="rgba(255, 255, 255)",
        stroke_width=stroke_width,
        stroke_color="rgb(0, 0, 0)",
        background_color="rgb(0, 0, 0)",
        background_image=image,
        update_streamlit=False,
        height=image.size[1],
        width=image.size[0],
        drawing_mode="freedraw",
        point_display_radius=0,
        key="canvas",
    )

    # Nothing to do until the canvas has actually sent pixel data back.
    if canvas_result is None or canvas_result.image_data is None:
        return

    # The alpha channel marks where the user drew; replicate it to 3 channels.
    mask = np.array(canvas_result.image_data)[:, :, [3, 3, 3]]
    mask = mask > 127

    # visualize: dim the masked region and outline it in red
    bdry = cv2.dilate(mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8))
    bdry = (bdry > 0) & ~mask

    pixels = np.array(image)
    masked_image = pixels * (1 - mask) + mask * 0.3 * pixels
    masked_image[:, :, 0][bdry[:, :, 0]] = 255
    masked_image[:, :, 1][bdry[:, :, 1]] = 0
    masked_image[:, :, 2][bdry[:, :, 2]] = 0

    st.write("Masked Image")
    st.image(Image.fromarray(masked_image.astype(np.uint8)))

    prompt = st.text_input("Prompt")
    scale = float(st.number_input("Guidance", value=10.0))
    t_total = int(st.number_input("Diffusion steps", value=50))

    if not st.button("Sample"):
        return

    st.text("Sampling")
    batch_progress = st.progress(0)
    batch_total = 3
    t_progress = st.progress(0)
    result = st.empty()  # placeholder slot, kept so the page layout is unchanged

    def callback(x, batch, t):
        # Progress reporting only. The previous body also called an
        # undefined ``update_canvas``/``canvas``; ``sample`` currently never
        # invokes the callback, so removing those calls changes nothing.
        batch_progress.progress(min(1.0, (batch + 1) / batch_total))
        t_progress.progress(min(1.0, (t + 1) / t_total))

    samples = sample(
        state["model"],
        prompt,
        n_runs=3,
        n_samples=2,
        H=512,
        W=512,
        scale=scale,
        ddim_steps=t_total,
        callback=callback,
        image=np.array(image),
        mask=np.array(mask),
    )
    st.text("Samples")
    st.image(samples[0])


if __name__ == "__main__":
    main()