diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 613de5e..ea94f86 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -4,6 +4,7 @@ import numpy as np
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
+from itertools import islice
 from einops import rearrange
 from torchvision.utils import make_grid
 
@@ -12,6 +13,11 @@ from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 
 
+def chunk(it, size):
+    it = iter(it)
+    return iter(lambda: tuple(islice(it, size)), ())
+
+
 def load_model_from_config(config, ckpt, verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
@@ -51,7 +57,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--ddim_steps",
         type=int,
-        default=200,
+        default=50,
        help="number of ddim sampling steps",
     )
 
@@ -91,8 +97,8 @@ if __name__ == "__main__":
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=4,
-        help="how many samples to produce for the given prompt",
+        default=8,
+        help="how many samples to produce for each given prompt. A.k.a batch size",
     )
 
     parser.add_argument(
@@ -101,11 +107,35 @@ if __name__ == "__main__":
         default=5.0,
         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
     )
+
+    parser.add_argument(
+        "--dyn",
+        type=float,
+        help="dynamic thresholding from Imagen, in latent space (TODO: try in pixel space with intermediate decode)",
+    )
+    parser.add_argument(
+        "--from-file",
+        type=str,
+        help="if specified, load prompts from this file",
+    )
+
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/configs/2022-06-01T22-11-40-project.yaml",
+        help="path to config which constructs model",
+    )
+    parser.add_argument(
+        "--ckpt",
+        type=str,
+        default="logs/f8-kl-clip-encoder-256x256-run1/checkpoints/last.ckpt",
+        help="path to checkpoint of model",
+    )
     opt = parser.parse_args()
 
-    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")  # TODO: Optionally download from same location as ckpt and chnage this logic
-    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path
+    config = OmegaConf.load(f"{opt.config}")
+    model = load_model_from_config(config, f"{opt.ckpt}")
 
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
@@ -118,48 +148,62 @@ if __name__ == "__main__":
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
 
-    prompt = opt.prompt
+    batch_size = opt.n_samples
+    if not opt.from_file:
+        prompt = opt.prompt
+        assert prompt is not None
+        data = [batch_size * [prompt]]
+    else:
+        print(f"reading prompts from {opt.from_file}")
+        with open(opt.from_file, "r") as f:
+            data = f.read().splitlines()
+            data = list(chunk(data, batch_size))
 
     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
     base_count = len(os.listdir(sample_path))
+    grid_count = len(os.listdir(outpath)) - 1
+
 
-    all_samples=list()
     with torch.no_grad():
         with model.ema_scope():
-            uc = None
-            if opt.scale != 1.0:
-                uc = model.get_learned_conditioning(opt.n_samples * [""])
             for n in trange(opt.n_iter, desc="Sampling"):
-                c = model.get_learned_conditioning(opt.n_samples * [prompt])
-                shape = [4, opt.H//8, opt.W//8]
-                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                 conditioning=c,
-                                                 batch_size=opt.n_samples,
-                                                 shape=shape,
-                                                 verbose=False,
-                                                 unconditional_guidance_scale=opt.scale,
-                                                 unconditional_conditioning=uc,
-                                                 eta=opt.ddim_eta)
+                all_samples = list()
+                for prompts in tqdm(data, desc="data"):
+                    uc = None
+                    if opt.scale != 1.0:
+                        uc = model.get_learned_conditioning(batch_size * [""])
+                    c = model.get_learned_conditioning(prompts)
+                    shape = [4, opt.H//8, opt.W//8]
+                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                     conditioning=c,
+                                                     batch_size=opt.n_samples,
+                                                     shape=shape,
+                                                     verbose=False,
+                                                     unconditional_guidance_scale=opt.scale,
+                                                     unconditional_conditioning=uc,
+                                                     eta=opt.ddim_eta,
+                                                     dynamic_threshold=opt.dyn)
 
-                x_samples_ddim = model.decode_first_stage(samples_ddim)
-                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+                    x_samples_ddim = model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
 
-                for x_sample in x_samples_ddim:
-                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
-                    base_count += 1
-                all_samples.append(x_samples_ddim)
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                        base_count += 1
+                    all_samples.append(x_samples_ddim)
 
-    # additionally, save as grid
-    grid = torch.stack(all_samples, 0)
-    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-    grid = make_grid(grid, nrow=opt.n_samples)
+                # additionally, save as grid
+                grid = torch.stack(all_samples, 0)
+                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                grid = make_grid(grid, nrow=opt.n_samples)
 
-    # to image
-    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))
+                # to image
+                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                grid_count += 1
 
-    print(f"Your samples are ready and waiting four you here: \n{outpath} \nEnjoy.")
+    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
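The new --from-file path reads one prompt per line and groups them into fixed-size batches via the chunk helper introduced at the top of the patch. A minimal standalone sketch of that batching behaviour (the example prompt list is purely illustrative):

    from itertools import islice

    def chunk(it, size):
        # yield successive `size`-length tuples until the iterable is exhausted
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())

    prompts = ["a cat", "a dog", "a house", "a tree", "a boat"]
    for batch in chunk(prompts, 2):
        print(batch)
    # ('a cat', 'a dog')
    # ('a house', 'a tree')
    # ('a boat',)

Note that the final batch can be shorter than batch_size when the prompt count is not a multiple of --n_samples, in which case the last call to sampler.sample receives fewer conditioning prompts than the stated batch_size.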