sampling updates

This commit is contained in:
rromb 2022-06-10 10:54:14 +02:00
parent 7ee2fc4310
commit b6cf25386a
2 changed files with 25 additions and 12 deletions

View file

@ -235,9 +235,14 @@ class CocoImagesAndCaptionsValidation2014(CocoBase):
return '2014' return '2014'
if __name__ == '__main__': if __name__ == '__main__':
with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
json_data = json.load(json_file)
capdirs = json_data["annotations"]
import pudb; pudb.set_trace()
#d2 = CocoImagesAndCaptionsTrain2014(size=256)
d2 = CocoImagesAndCaptionsValidation2014(size=256) d2 = CocoImagesAndCaptionsValidation2014(size=256)
print("construced val set.") print("constructed dataset.")
print(f"length of train split: {len(d2)}") print(f"length of {d2.__class__.__name__}: {len(d2)}")
ex2 = d2[0] ex2 = d2[0]
# ex3 = d3[0] # ex3 = d3[0]

View file

@ -54,6 +54,13 @@ if __name__ == "__main__":
help="dir to write results to", help="dir to write results to",
default="outputs/txt2img-samples" default="outputs/txt2img-samples"
) )
parser.add_argument(
"--skip_grid",
action='store_true',
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument( parser.add_argument(
"--ddim_steps", "--ddim_steps",
type=int, type=int,
@ -165,7 +172,6 @@ if __name__ == "__main__":
base_count = len(os.listdir(sample_path)) base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
with torch.no_grad(): with torch.no_grad():
with model.ema_scope(): with model.ema_scope():
for n in trange(opt.n_iter, desc="Sampling"): for n in trange(opt.n_iter, desc="Sampling"):
@ -174,6 +180,8 @@ if __name__ == "__main__":
uc = None uc = None
if opt.scale != 1.0: if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""]) uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
shape = [4, opt.H//8, opt.W//8] shape = [4, opt.H//8, opt.W//8]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps, samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
@ -195,15 +203,15 @@ if __name__ == "__main__":
base_count += 1 base_count += 1
all_samples.append(x_samples_ddim) all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=opt.n_samples)
# additionally, save as grid # to image
grid = torch.stack(all_samples, 0) grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
grid = rearrange(grid, 'n b c h w -> (n b) c h w') Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = make_grid(grid, nrow=opt.n_samples) grid_count += 1
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")