allow n row specification when sampling with small batch size
This commit is contained in:
parent
d06c2277b0
commit
a66b27b149
1 changed files with 11 additions and 1 deletions
|
@ -21,6 +21,8 @@ def chunk(it, size):
|
|||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
model = instantiate_from_config(config.model)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
@ -108,6 +110,13 @@ if __name__ == "__main__":
|
|||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--n_rows",
|
||||
type=int,
|
||||
default=0,
|
||||
help="rows in the grid (default: n_samples)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--scale",
|
||||
type=float,
|
||||
|
@ -156,6 +165,7 @@ if __name__ == "__main__":
|
|||
outpath = opt.outdir
|
||||
|
||||
batch_size = opt.n_samples
|
||||
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||
if not opt.from_file:
|
||||
prompt = opt.prompt
|
||||
assert prompt is not None
|
||||
|
@ -207,7 +217,7 @@ if __name__ == "__main__":
|
|||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
grid = make_grid(grid, nrow=opt.n_samples)
|
||||
grid = make_grid(grid, nrow=n_rows)
|
||||
|
||||
# to image
|
||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||
|
|
Loading…
Reference in a new issue