From a66b27b1493321df7e0ac3bd440448bcf4b7b609 Mon Sep 17 00:00:00 2001 From: Patrick Esser Date: Sat, 11 Jun 2022 18:35:09 -0400 Subject: [PATCH] allow n row specification when sampling with small batch size --- scripts/txt2img.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 0a6ca3e..d9ec628 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -21,6 +21,8 @@ def chunk(it, size): def load_model_from_config(config, ckpt, verbose=False): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) @@ -108,6 +110,13 @@ if __name__ == "__main__": help="how many samples to produce for each given prompt. A.k.a batch size", ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( "--scale", type=float, @@ -156,6 +165,7 @@ if __name__ == "__main__": outpath = opt.outdir batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size if not opt.from_file: prompt = opt.prompt assert prompt is not None @@ -207,7 +217,7 @@ if __name__ == "__main__": # additionally, save as grid grid = torch.stack(all_samples, 0) grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=opt.n_samples) + grid = make_grid(grid, nrow=n_rows) # to image grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()