allow n row specification when sampling with small batch size
This commit is contained in:
parent
d06c2277b0
commit
a66b27b149
1 changed files with 11 additions and 1 deletions
|
@ -21,6 +21,8 @@ def chunk(it, size):
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
print(f"Loading model from {ckpt}")
|
print(f"Loading model from {ckpt}")
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
sd = pl_sd["state_dict"]
|
sd = pl_sd["state_dict"]
|
||||||
model = instantiate_from_config(config.model)
|
model = instantiate_from_config(config.model)
|
||||||
m, u = model.load_state_dict(sd, strict=False)
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
@ -108,6 +110,13 @@ if __name__ == "__main__":
|
||||||
help="how many samples to produce for each given prompt. A.k.a batch size",
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--scale",
|
"--scale",
|
||||||
type=float,
|
type=float,
|
||||||
|
@ -156,6 +165,7 @@ if __name__ == "__main__":
|
||||||
outpath = opt.outdir
|
outpath = opt.outdir
|
||||||
|
|
||||||
batch_size = opt.n_samples
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
if not opt.from_file:
|
if not opt.from_file:
|
||||||
prompt = opt.prompt
|
prompt = opt.prompt
|
||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
|
@ -207,7 +217,7 @@ if __name__ == "__main__":
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = torch.stack(all_samples, 0)
|
grid = torch.stack(all_samples, 0)
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
grid = make_grid(grid, nrow=opt.n_samples)
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
# to image
|
# to image
|
||||||
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
|
Loading…
Reference in a new issue