sampling updates
This commit is contained in:
parent
7ee2fc4310
commit
b6cf25386a
2 changed files with 25 additions and 12 deletions
|
@ -235,9 +235,14 @@ class CocoImagesAndCaptionsValidation2014(CocoBase):
|
|||
return '2014'
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
|
||||
json_data = json.load(json_file)
|
||||
capdirs = json_data["annotations"]
|
||||
import pudb; pudb.set_trace()
|
||||
#d2 = CocoImagesAndCaptionsTrain2014(size=256)
|
||||
d2 = CocoImagesAndCaptionsValidation2014(size=256)
|
||||
print("construced val set.")
|
||||
print(f"length of train split: {len(d2)}")
|
||||
print("constructed dataset.")
|
||||
print(f"length of {d2.__class__.__name__}: {len(d2)}")
|
||||
|
||||
ex2 = d2[0]
|
||||
# ex3 = d3[0]
|
||||
|
|
|
@ -54,6 +54,13 @@ if __name__ == "__main__":
|
|||
help="dir to write results to",
|
||||
default="outputs/txt2img-samples"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--skip_grid",
|
||||
action='store_true',
|
||||
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ddim_steps",
|
||||
type=int,
|
||||
|
@ -165,7 +172,6 @@ if __name__ == "__main__":
|
|||
base_count = len(os.listdir(sample_path))
|
||||
grid_count = len(os.listdir(outpath)) - 1
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
with model.ema_scope():
|
||||
for n in trange(opt.n_iter, desc="Sampling"):
|
||||
|
@ -174,6 +180,8 @@ if __name__ == "__main__":
|
|||
uc = None
|
||||
if opt.scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [4, opt.H//8, opt.W//8]
|
||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||
|
@ -195,7 +203,7 @@ if __name__ == "__main__":
|
|||
base_count += 1
|
||||
all_samples.append(x_samples_ddim)
|
||||
|
||||
|
||||
if not opt.skip_grid:
|
||||
# additionally, save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
|
|
Loading…
Reference in a new issue