sampling updates
This commit is contained in:
parent
7ee2fc4310
commit
b6cf25386a
2 changed files with 25 additions and 12 deletions
|
@ -235,9 +235,14 @@ class CocoImagesAndCaptionsValidation2014(CocoBase):
|
||||||
return '2014'
|
return '2014'
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
|
||||||
|
json_data = json.load(json_file)
|
||||||
|
capdirs = json_data["annotations"]
|
||||||
|
import pudb; pudb.set_trace()
|
||||||
|
#d2 = CocoImagesAndCaptionsTrain2014(size=256)
|
||||||
d2 = CocoImagesAndCaptionsValidation2014(size=256)
|
d2 = CocoImagesAndCaptionsValidation2014(size=256)
|
||||||
print("construced val set.")
|
print("constructed dataset.")
|
||||||
print(f"length of train split: {len(d2)}")
|
print(f"length of {d2.__class__.__name__}: {len(d2)}")
|
||||||
|
|
||||||
ex2 = d2[0]
|
ex2 = d2[0]
|
||||||
# ex3 = d3[0]
|
# ex3 = d3[0]
|
||||||
|
|
|
@ -54,6 +54,13 @@ if __name__ == "__main__":
|
||||||
help="dir to write results to",
|
help="dir to write results to",
|
||||||
default="outputs/txt2img-samples"
|
default="outputs/txt2img-samples"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ddim_steps",
|
"--ddim_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
@ -165,7 +172,6 @@ if __name__ == "__main__":
|
||||||
base_count = len(os.listdir(sample_path))
|
base_count = len(os.listdir(sample_path))
|
||||||
grid_count = len(os.listdir(outpath)) - 1
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with model.ema_scope():
|
with model.ema_scope():
|
||||||
for n in trange(opt.n_iter, desc="Sampling"):
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
@ -174,6 +180,8 @@ if __name__ == "__main__":
|
||||||
uc = None
|
uc = None
|
||||||
if opt.scale != 1.0:
|
if opt.scale != 1.0:
|
||||||
uc = model.get_learned_conditioning(batch_size * [""])
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
c = model.get_learned_conditioning(prompts)
|
c = model.get_learned_conditioning(prompts)
|
||||||
shape = [4, opt.H//8, opt.W//8]
|
shape = [4, opt.H//8, opt.W//8]
|
||||||
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
|
@ -195,7 +203,7 @@ if __name__ == "__main__":
|
||||||
base_count += 1
|
base_count += 1
|
||||||
all_samples.append(x_samples_ddim)
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
# additionally, save as grid
|
# additionally, save as grid
|
||||||
grid = torch.stack(all_samples, 0)
|
grid = torch.stack(all_samples, 0)
|
||||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
|
Loading…
Reference in a new issue