better init image handling

This commit is contained in:
Robin Rombach 2022-08-01 00:14:18 +02:00
parent 4358de0a9a
commit a416813c32

View file

@ -1,6 +1,7 @@
"""make variations of input image""" """make variations of input image"""
import argparse, os, sys, glob import argparse, os, sys, glob
import PIL
import torch import torch
import numpy as np import numpy as np
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -9,6 +10,8 @@ from tqdm import tqdm, trange
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
from torchvision.utils import make_grid from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
import time import time
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
@ -43,8 +46,12 @@ def load_model_from_config(config, ckpt, verbose=False):
def load_img(path):
    """Load an image file as a normalized 4-D float tensor.

    The image is opened with ``Image.open``, converted to RGB and resized
    (LANCZOS resampling) so that its width and height are each rounded down
    to an integer multiple of 32.

    Args:
        path: Location of the image, passed straight to ``Image.open``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, h, w)`` whose values
        are rescaled from ``[0, 255]`` to ``[-1, 1]``.

    Raises:
        ValueError: If either side of the image is shorter than 32 pixels,
            because rounding down would leave a zero-sized dimension.
    """
    image = Image.open(path).convert("RGB")
    src_w, src_h = image.size
    print(f"loaded input image of size ({src_w}, {src_h}) from {path}")

    # Round each side down to the nearest integer multiple of 32.
    w, h = (side - side % 32 for side in (src_w, src_h))
    if w == 0 or h == 0:
        # Fail early with an explicit message rather than handing a
        # zero-sized target to resize().
        raise ValueError(
            f"image {path} of size ({src_w}, {src_h}) is too small: "
            "both sides must be at least 32 pixels"
        )

    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0
    # (h, w, 3) -> (1, h, w, 3) -> (1, 3, h, w)
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)
    return 2.*tensor - 1.
@ -119,20 +126,6 @@ def main():
help="sample this often", help="sample this often",
) )
parser.add_argument(
"--H",
type=int,
default=256,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=256,
help="image width, in pixel space",
)
parser.add_argument( parser.add_argument(
"--C", "--C",
type=int, type=int,
@ -149,7 +142,7 @@ def main():
parser.add_argument( parser.add_argument(
"--n_samples", "--n_samples",
type=int, type=int,
default=8, default=2,
help="how many samples to produce for each given prompt. A.k.a batch size", help="how many samples to produce for each given prompt. A.k.a batch size",
) )
@ -170,7 +163,7 @@ def main():
parser.add_argument( parser.add_argument(
"--strength", "--strength",
type=float, type=float,
default=0.3, default=0.75,
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image", help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
) )
@ -197,6 +190,14 @@ def main():
default=42, default=42,
help="the seed (for reproducible sampling)", help="the seed (for reproducible sampling)",
) )
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
)
opt = parser.parse_args() opt = parser.parse_args()
seed_everything(opt.seed) seed_everything(opt.seed)
@ -244,7 +245,9 @@ def main():
t_enc = int(opt.strength * opt.ddim_steps) t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps") print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad(): with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope(): with model.ema_scope():
tic = time.time() tic = time.time()
all_samples = list() all_samples = list()
@ -288,7 +291,7 @@ def main():
toc = time.time() toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec." f"Sampling took {toc - tic}s, i.e., produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
f" \nEnjoy.") f" \nEnjoy.")