allow n row specification when sampling with small batch size

This commit is contained in:
Patrick Esser 2022-06-11 18:35:09 -04:00
parent d06c2277b0
commit a66b27b149

View file

@ -21,6 +21,8 @@ def chunk(it, size):
def load_model_from_config(config, ckpt, verbose=False): def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False) m, u = model.load_state_dict(sd, strict=False)
@ -108,6 +110,13 @@ if __name__ == "__main__":
help="how many samples to produce for each given prompt. A.k.a batch size", help="how many samples to produce for each given prompt. A.k.a batch size",
) )
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument( parser.add_argument(
"--scale", "--scale",
type=float, type=float,
@ -156,6 +165,7 @@ if __name__ == "__main__":
outpath = opt.outdir outpath = opt.outdir
batch_size = opt.n_samples batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file: if not opt.from_file:
prompt = opt.prompt prompt = opt.prompt
assert prompt is not None assert prompt is not None
@ -207,7 +217,7 @@ if __name__ == "__main__":
# additionally, save as grid # additionally, save as grid
grid = torch.stack(all_samples, 0) grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w') grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=opt.n_samples) grid = make_grid(grid, nrow=n_rows)
# to image # to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()