eval script

rromb 2022-06-05 19:22:43 +02:00
parent cec5968820
commit 36b5177221
1 changed file with 79 additions and 35 deletions


@@ -4,6 +4,7 @@ import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
+from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
@@ -12,6 +13,11 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler


+def chunk(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
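
The new chunk() helper is what makes file-based batching work: the two-argument form iter(callable, sentinel) keeps calling tuple(islice(it, size)) until it returns an empty tuple, splitting any iterable into size-sized pieces. A minimal sketch of its behaviour on a hypothetical prompt list:

    from itertools import islice

    def chunk(it, size):
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())

    prompts = ["a cat", "a dog", "a fox", "a bee", "an owl"]
    print(list(chunk(prompts, 2)))
    # [('a cat', 'a dog'), ('a fox', 'a bee'), ('an owl',)]

Note that the final chunk can be shorter than the requested size, so a prompt file whose line count is not a multiple of the batch size ends with a ragged batch.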
@@ -51,7 +57,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--ddim_steps",
         type=int,
-        default=200,
+        default=50,
         help="number of ddim sampling steps",
     )
@@ -91,8 +97,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=4,
-        help="how many samples to produce for the given prompt",
+        default=8,
+        help="how many samples to produce for each given prompt. A.k.a batch size",
     )
     parser.add_argument(
@@ -101,11 +107,35 @@ if __name__ == "__main__":
         default=5.0,
         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
     )
+    parser.add_argument(
+        "--dyn",
+        type=float,
+        help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)",
+    )
+    parser.add_argument(
+        "--from-file",
+        type=str,
+        help="if specified, load prompts from this file",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml",
+        help="path to config which constructs model",
+    )
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
+        help="path to checkpoint of model",
+    )
     opt = parser.parse_args()

-    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")  # TODO: Optionally download from same location as ckpt and chnage this logic
-    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path
+    config = OmegaConf.load(f"{opt.config}")
+    model = load_model_from_config(config, f"{opt.ckpt}")
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
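
The --scale help string spells out classifier-free guidance, and --dyn feeds an Imagen-style dynamic threshold into the sampler. Neither implementation appears in this diff, so the following is only an illustrative sketch under assumed 4-D latent shapes: guided_eps transcribes the formula from the help text, and dynamic_threshold follows the Imagen paper (clamp each sample to the p-th percentile s of |x|, then divide by s), which may differ from what DDIMSampler actually does with dynamic_threshold=opt.dyn.

    import torch

    def guided_eps(eps_uncond, eps_cond, scale):
        # eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
        return eps_uncond + scale * (eps_cond - eps_uncond)

    def dynamic_threshold(x, p=0.95):
        # per-sample percentile of |x| over all latent dimensions
        s = torch.quantile(x.abs().reshape(x.shape[0], -1), p, dim=1)
        s = torch.clamp(s, min=1.0).view(-1, 1, 1, 1)
        # clamp to [-s, s], then rescale so values stay within [-1, 1]
        return x.clamp(-s, s) / s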
@@ -118,48 +148,62 @@ if __name__ == "__main__":
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir

-    prompt = opt.prompt
+    batch_size = opt.n_samples
+    if not opt.from_file:
+        prompt = opt.prompt
+        assert prompt is not None
+        data = [batch_size * [prompt]]
+    else:
+        print(f"reading prompts from {opt.from_file}")
+        with open(opt.from_file, "r") as f:
+            data = f.read().splitlines()
+            data = list(chunk(data, batch_size))

     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
+    grid_count = len(os.listdir(outpath)) - 1

-    all_samples=list()
     with torch.no_grad():
         with model.ema_scope():
-            uc = None
-            if opt.scale != 1.0:
-                uc = model.get_learned_conditioning(opt.n_samples * [""])
             for n in trange(opt.n_iter, desc="Sampling"):
-                c = model.get_learned_conditioning(opt.n_samples * [prompt])
-                shape = [4, opt.H//8, opt.W//8]
-                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                 conditioning=c,
-                                                 batch_size=opt.n_samples,
-                                                 shape=shape,
-                                                 verbose=False,
-                                                 unconditional_guidance_scale=opt.scale,
-                                                 unconditional_conditioning=uc,
-                                                 eta=opt.ddim_eta)
-
-                x_samples_ddim = model.decode_first_stage(samples_ddim)
-                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
-
-                for x_sample in x_samples_ddim:
-                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
-                    base_count += 1
-                all_samples.append(x_samples_ddim)
+                all_samples = list()
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    if opt.scale != 1.0:
+                        uc = model.get_learned_conditioning(batch_size * [""])
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [4, opt.H//8, opt.W//8]
+                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                     conditioning=c,
+                                                     batch_size=opt.n_samples,
+                                                     shape=shape,
+                                                     verbose=False,
+                                                     unconditional_guidance_scale=opt.scale,
+                                                     unconditional_conditioning=uc,
+                                                     eta=opt.ddim_eta,
+                                                     dynamic_threshold=opt.dyn)
+
+                    x_samples_ddim = model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        base_count += 1
+                    all_samples.append(x_samples_ddim)

-    # additionally, save as grid
-    grid = torch.stack(all_samples, 0)
-    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-    grid = make_grid(grid, nrow=opt.n_samples)
-
-    # to image
-    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=opt.n_samples)
+
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1

-    print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.")
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")