SD inpainting eval script and demo
This commit is contained in:
parent
81d560ef1f
commit
693e713c3e
2 changed files with 201 additions and 4 deletions
|
@ -114,20 +114,22 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def run(
|
||||||
|
config="/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml",
|
||||||
|
#ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-28T07-44-05_v1-finetune-for-inpainting-laion-aesthetic-larger-masks/checkpoints/last.ckpt",
|
||||||
|
ckpt="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-08-01T08-52-14_v1-finetune-for-inpainting-laion-aesthetic-larger-masks-and-ucfg/checkpoints/last.ckpt",
|
||||||
|
):
|
||||||
st.title("Stable Inpainting")
|
st.title("Stable Inpainting")
|
||||||
state = init()
|
state = init()
|
||||||
|
|
||||||
if not "model" in state:
|
if not "model" in state:
|
||||||
config = "/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml"
|
|
||||||
ckpt = "/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-07-24T16-01-27_v1-finetune-for-inpainting-laion-iaesthe/checkpoints/last.ckpt"
|
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
model = load_model_from_config(config, ckpt)
|
model = load_model_from_config(config, ckpt)
|
||||||
state["model"] = model
|
state["model"] = model
|
||||||
|
|
||||||
uploaded_file = st.file_uploader("Upload image to inpaint")
|
uploaded_file = st.file_uploader("Upload image to inpaint")
|
||||||
if uploaded_file is not None:
|
if uploaded_file is not None:
|
||||||
image = Image.open(io.BytesIO(uploaded_file.getvalue()))
|
image = Image.open(io.BytesIO(uploaded_file.getvalue())).convert("RGB")
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
smaller = min(width, height)
|
smaller = min(width, height)
|
||||||
crop = (
|
crop = (
|
||||||
|
@ -206,3 +208,8 @@ if __name__ == "__main__":
|
||||||
)
|
)
|
||||||
st.text("Samples")
|
st.text("Samples")
|
||||||
st.image(samples[0])
|
st.image(samples[0])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import fire
|
||||||
|
fire.Fire(run)
|
||||||
|
|
190
scripts/inpaint_sd.py
Normal file
190
scripts/inpaint_sd.py
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
import argparse, os, sys, glob
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from main import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
|
||||||
|
|
||||||
|
def make_batch_ldm(image, mask, device):
    """Build an inpainting batch in the original latent-diffusion convention.

    Args:
        image: path to an RGB image file.
        mask: path to a grayscale mask file (white = region to inpaint).
        device: torch device the batch tensors are moved to.

    Returns:
        dict with keys "image", "mask", "masked_image". Every tensor is
        moved to ``device`` and rescaled from [0, 1] to [-1, 1] — note the
        binarized mask is rescaled too, so it takes values in {-1, 1}.
    """
    # Load the image as float32 NCHW in [0, 1]; the context manager closes
    # the underlying file handle (Image.open is lazy and keeps it open).
    with Image.open(image) as img:
        image = np.array(img.convert("RGB"))
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)

    # Load the mask and binarize at 0.5 so it is strictly {0, 1}.
    with Image.open(mask) as m:
        mask = np.array(m.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Zero out the region that is to be inpainted.
    masked_image = (1 - mask) * image

    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        # Rescale [0, 1] -> [-1, 1] (applied to the mask as well, matching
        # the original LDM inpainting convention).
        batch[k] = batch[k] * 2.0 - 1.0
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
def make_batch_sd(
        image,
        mask,
        txt,
        device):
    """Build an inpainting batch in the SD fine-tuning convention.

    Args:
        image: path to an RGB image file.
        mask: path to a grayscale mask file (white = region to inpaint).
        txt: text prompt for the conditioning stage.
        device: torch device the tensors are moved to.

    Returns:
        dict with keys "jpg" (image in [-1, 1], NCHW), "txt" (single-element
        prompt list), "mask" (binary {0, 1}), and "masked_image" (image with
        the masked region zeroed, still in [-1, 1]).
    """
    # image hwc in -1 1
    # Context manager closes the file handle PIL keeps open lazily.
    with Image.open(image) as img:
        image = np.array(img.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

    # Load the mask and binarize at 0.5; unlike make_batch_ldm, the mask
    # stays in {0, 1} here.
    with Image.open(mask) as m:
        mask = np.array(m.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)

    # Keep only the unmasked pixels (mask < 0.5 selects the background).
    masked_image = image * (mask < 0.5)

    batch = {
        "jpg": image.to(device=device),
        "txt": [txt],
        "mask": mask.to(device=device),
        "masked_image": masked_image.to(device=device),
    }
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Batch inpainting eval: reads `example.png` / `example_mask000.png`
    # pairs from --indir, runs DDIM sampling with classifier-free guidance,
    # and writes composited results to --outdir. Supports sharding across
    # workers via --worldsize / --rank.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--indir",
        type=str,
        nargs="?",
        help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta of ddim",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=6.0,
        help="scale of unconditional guidance",
    )
    parser.add_argument(
        "--worldsize",
        type=int,
        default=1,
        # was a copy-pasted "scale of unconditional guidance" help string
        help="total number of workers the inputs are sharded across",
    )
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        # was a copy-pasted "scale of unconditional guidance" help string
        help="index of this worker in [0, worldsize)",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="/fsx/robin/stable-diffusion/stable-diffusion/logs/2022-08-01T08-52-14_v1-finetune-for-inpainting-laion-aesthetic-larger-masks-and-ucfg/checkpoints/last.ckpt",
        # was a copy-pasted "scale of unconditional guidance" help string
        help="path to the model checkpoint to load",
    )
    opt = parser.parse_args()

    assert opt.rank < opt.worldsize

    # Inputs are pairs `foo.png` / `foo_mask000.png` inside --indir.
    mstr = "mask000.png"
    masks = sorted(glob.glob(os.path.join(opt.indir, f"*_{mstr}")))
    images = [x.replace(f"_{mstr}", ".png") for x in masks]
    print(f"Found {len(masks)} inputs.")

    #config = "models/ldm/inpainting_big/config.yaml"
    config = "/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-finetune-for-inpainting-laion-iaesthe.yaml"
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    #ckpt="models/ldm/inpainting_big/last.ckpt"
    ckpt = opt.ckpt
    # strict=False: the fine-tuned checkpoint may omit/add some keys
    # relative to the instantiated architecture.
    model.load_state_dict(torch.load(ckpt)["state_dict"],
                          strict=False)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    # Shard the work: this worker processes every worldsize-th input.
    indices = [i for i in range(len(images)) if i % opt.worldsize == opt.rank]
    images = [images[i] for i in indices]
    masks = [masks[i] for i in indices]

    os.makedirs(opt.outdir, exist_ok=True)
    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks), total=len(images)):
                outpath = os.path.join(opt.outdir, os.path.split(image)[1])
                #batch = make_batch_ldm(image, mask, device=device)

                ##### unroll
                batch = make_batch_sd(image, mask, txt="photograph of a beautiful empty scene, highest quality settings",
                                      device=device)

                # Cross-attention conditioning from the prompt.
                c = model.cond_stage_model.encode(batch["txt"])

                # Concat conditioning: downsampled mask + encoded masked image.
                c_cat = list()
                for ck in model.concat_keys:
                    cc = batch[ck].float()
                    if ck != model.masked_image_key:
                        # Resize the mask to the latent spatial resolution.
                        bchw = (1, 4, 64, 64)
                        cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
                    else:
                        cc = model.get_first_stage_encoding(model.encode_first_stage(cc))
                    c_cat.append(cc)
                c_cat = torch.cat(c_cat, dim=1)

                # cond
                cond = {"c_concat": [c_cat], "c_crossattn": [c]}

                # uncond cond — same concat conditioning, empty prompt.
                uc_cross = model.get_unconditional_conditioning(1, "")
                uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}

                shape = (model.channels, model.image_size, model.image_size)
                samples_cfg, intermediates = sampler.sample(
                    opt.steps,
                    1,
                    shape,
                    cond,
                    verbose=False,
                    eta=opt.eta,
                    unconditional_guidance_scale=opt.scale,
                    unconditional_conditioning=uc_full,
                )
                x_samples_ddim = model.decode_first_stage(samples_cfg)

                image = torch.clamp((batch["jpg"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                # BUGFIX: make_batch_sd keeps the mask in {0, 1} (it is never
                # rescaled to [-1, 1]), so the previous `(mask+1)/2` mapped
                # background pixels to 0.5 and blended the input 50/50 with
                # the prediction everywhere. Clamp without rescaling instead.
                mask = torch.clamp(batch["mask"],
                                   min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
                                              min=0.0, max=1.0)

                # Composite: keep original pixels outside the mask, take the
                # model prediction inside it.
                inpainted = (1-mask)*image+mask*predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0]*255
                Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
|
Loading…
Reference in a new issue