better init image handling

Robin Rombach 2022-08-01 00:14:18 +02:00
parent 4358de0a9a
commit a416813c32
1 changed file with 57 additions and 54 deletions

@@ -1,6 +1,7 @@
 """make variations of input image"""
 
 import argparse, os, sys, glob
+import PIL
 import torch
 import numpy as np
 from omegaconf import OmegaConf
@@ -9,6 +10,8 @@ from tqdm import tqdm, trange
 from itertools import islice
 from einops import rearrange, repeat
 from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
 import time
 from pytorch_lightning import seed_everything
@@ -43,8 +46,12 @@ def load_model_from_config(config, ckpt, verbose=False):
 def load_img(path):
-    image = np.array(Image.open(path).convert("RGB"))
-    image = image.astype(np.float32) / 255.0
+    image = Image.open(path).convert("RGB")
+    w, h = image.size
+    print(f"loaded input image of size ({w}, {h}) from {path}")
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return 2.*image - 1.
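
Note on the new load_img: the model's downsampling stages (a factor-8 autoencoder, plus further downsampling inside the UNet) need input dimensions divisible by 32, so the loader now snaps each side of the init image down with x - x % 32 and resizes with Lanczos; this is also why the fixed --H/--W flags are removed in the next hunk. A standalone sketch of the rounding, with made-up example sizes:

    def snap_down_to_multiple_of_32(x: int) -> int:
        # subtracting the remainder rounds down to the nearest multiple of 32
        return x - x % 32

    assert snap_down_to_multiple_of_32(512) == 512  # already aligned, unchanged
    assert snap_down_to_multiple_of_32(513) == 512  # rounds down, never up
    assert snap_down_to_multiple_of_32(700) == 672  # 672 == 21 * 32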
@@ -119,20 +126,6 @@ def main():
         help="sample this often",
     )
-    parser.add_argument(
-        "--H",
-        type=int,
-        default=256,
-        help="image height, in pixel space",
-    )
-    parser.add_argument(
-        "--W",
-        type=int,
-        default=256,
-        help="image width, in pixel space",
-    )
     parser.add_argument(
         "--C",
         type=int,
@@ -149,7 +142,7 @@ def main():
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=8,
+        default=2,
         help="how many samples to produce for each given prompt. A.k.a batch size",
     )
@@ -170,7 +163,7 @@ def main():
     parser.add_argument(
         "--strength",
         type=float,
-        default=0.3,
+        default=0.75,
         help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
     )
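
The higher strength default interacts with --ddim_steps through t_enc = int(opt.strength * opt.ddim_steps), visible at the top of the final hunk: strength sets the fraction of the diffusion trajectory applied to the init latent. A quick sketch of the arithmetic, assuming the script's default of 50 DDIM steps:

    ddim_steps = 50  # assumed --ddim_steps default
    for strength in (0.3, 0.75, 1.0):
        t_enc = int(strength * ddim_steps)
        print(f"strength {strength} -> {t_enc}/{ddim_steps} steps")
    # prints: 0.3 -> 15/50, 0.75 -> 37/50, 1.0 -> 50/50 (init image fully destroyed)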
@@ -197,6 +190,14 @@ def main():
         default=42,
         help="the seed (for reproducible sampling)",
     )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        help="evaluate at this precision",
+        choices=["full", "autocast"],
+        default="autocast"
+    )
     opt = parser.parse_args()
 
     seed_everything(opt.seed)
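
This flag drives the precision_scope selection in the sampling loop below. The pattern works because torch.autocast takes a device string while contextlib.nullcontext accepts and ignores a positional argument, so the single call site precision_scope("cuda") is valid either way. A minimal self-contained sketch of the same trick:

    from contextlib import nullcontext

    from torch import autocast

    use_autocast = True  # stands in for opt.precision == "autocast"
    precision_scope = autocast if use_autocast else nullcontext
    with precision_scope("cuda"):
        pass  # model calls here run in fp16 under autocast, or untouched at full precision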
@@ -244,51 +245,53 @@ def main():
     t_enc = int(opt.strength * opt.ddim_steps)
     print(f"target t_enc is {t_enc} steps")
 
+    precision_scope = autocast if opt.precision == "autocast" else nullcontext
     with torch.no_grad():
-        with model.ema_scope():
-            tic = time.time()
-            all_samples = list()
-            for n in trange(opt.n_iter, desc="Sampling"):
-                for prompts in tqdm(data, desc="data"):
-                    uc = None
-                    if opt.scale != 1.0:
-                        uc = model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-                    c = model.get_learned_conditioning(prompts)
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(opt.n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
+                        uc = None
+                        if opt.scale != 1.0:
+                            uc = model.get_learned_conditioning(batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = model.get_learned_conditioning(prompts)
 
-                    # encode (scaled latent)
-                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
-                    # decode it
-                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
-                                             unconditional_conditioning=uc,)
+                        # encode (scaled latent)
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+                        # decode it
+                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
+                                                 unconditional_conditioning=uc,)
 
-                    x_samples = model.decode_first_stage(samples)
-                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                        x_samples = model.decode_first_stage(samples)
+                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
-                    if not opt.skip_save:
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            Image.fromarray(x_sample.astype(np.uint8)).save(
-                                os.path.join(sample_path, f"{base_count:05}.png"))
-                            base_count += 1
-                    all_samples.append(x_samples)
+                        if not opt.skip_save:
+                            for x_sample in x_samples:
+                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                Image.fromarray(x_sample.astype(np.uint8)).save(
+                                    os.path.join(sample_path, f"{base_count:05}.png"))
+                                base_count += 1
+                        all_samples.append(x_samples)
 
-            if not opt.skip_grid:
-                # additionally, save as grid
-                grid = torch.stack(all_samples, 0)
-                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                grid = make_grid(grid, nrow=n_rows)
+                if not opt.skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
 
-                # to image
-                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                grid_count += 1
+                    # to image
+                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
 
-            toc = time.time()
+                toc = time.time()
 
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
          f" \nEnjoy.")