From b6cf25386a928fef38913530fb377b4ea47134f1 Mon Sep 17 00:00:00 2001 From: rromb Date: Fri, 10 Jun 2022 10:54:14 +0200 Subject: [PATCH] sampling updates --- ldm/data/coco.py | 9 +++++++-- scripts/txt2img.py | 28 ++++++++++++++++++---------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/ldm/data/coco.py b/ldm/data/coco.py index 7654069..5e5e27e 100644 --- a/ldm/data/coco.py +++ b/ldm/data/coco.py @@ -235,9 +235,14 @@ class CocoImagesAndCaptionsValidation2014(CocoBase): return '2014' if __name__ == '__main__': + with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file: + json_data = json.load(json_file) + capdirs = json_data["annotations"] + import pudb; pudb.set_trace() + #d2 = CocoImagesAndCaptionsTrain2014(size=256) d2 = CocoImagesAndCaptionsValidation2014(size=256) - print("construced val set.") - print(f"length of train split: {len(d2)}") + print("constructed dataset.") + print(f"length of {d2.__class__.__name__}: {len(d2)}") ex2 = d2[0] # ex3 = d3[0] diff --git a/scripts/txt2img.py b/scripts/txt2img.py index ea94f86..0a6ca3e 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -54,6 +54,13 @@ if __name__ == "__main__": help="dir to write results to", default="outputs/txt2img-samples" ) + + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( "--ddim_steps", type=int, @@ -165,7 +172,6 @@ if __name__ == "__main__": base_count = len(os.listdir(sample_path)) grid_count = len(os.listdir(outpath)) - 1 - with torch.no_grad(): with model.ema_scope(): for n in trange(opt.n_iter, desc="Sampling"): @@ -174,6 +180,8 @@ if __name__ == "__main__": uc = None if opt.scale != 1.0: uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) c = model.get_learned_conditioning(prompts) shape = [4, opt.H//8, opt.W//8] samples_ddim, _ = sampler.sample(S=opt.ddim_steps, @@ -195,15 +203,15 @@ if __name__ == "__main__": base_count += 1 all_samples.append(x_samples_ddim) + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=opt.n_samples) - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=opt.n_samples) - - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")