From 62308078cf511aa5ac8083453522764c65eb4c03 Mon Sep 17 00:00:00 2001 From: Robin Rombach Date: Fri, 29 Jul 2022 21:34:02 +0200 Subject: [PATCH] autocast --- scripts/txt2img.py | 93 ++++++++++++++++++++++++++-------------------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 37797ac..ef52ee0 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -9,6 +9,8 @@ from einops import rearrange from torchvision.utils import make_grid import time from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler @@ -178,6 +180,13 @@ def main(): default=42, help="the seed (for reproducible sampling)", ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) opt = parser.parse_args() seed_everything(opt.seed) @@ -217,53 +226,55 @@ def main(): if opt.fixed_code: start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + precision_scope = autocast if opt.precision=="autocast" else nullcontext with torch.no_grad(): - with model.ema_scope(): - tic = time.time() - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - dynamic_threshold=opt.dyn, - x_T=start_code) + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + dynamic_threshold=opt.dyn, + x_T=start_code) - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - if not opt.skip_save: - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - all_samples.append(x_samples_ddim) + if not opt.skip_save: + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + all_samples.append(x_samples_ddim) - if not opt.skip_grid: - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 - toc = time.time() + toc = time.time() print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."