Merge remote-tracking branch 'origin/main'

Patrick Esser 2022-07-31 23:27:09 +00:00
commit 85868a5d34
4 changed files with 131 additions and 99 deletions

View File

@@ -1,6 +1,7 @@
 """make variations of input image"""
 import argparse, os, sys, glob
+import PIL
 import torch
 import numpy as np
 from omegaconf import OmegaConf
@@ -9,6 +10,8 @@ from tqdm import tqdm, trange
 from itertools import islice
 from einops import rearrange, repeat
 from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
 import time
 from pytorch_lightning import seed_everything
@@ -43,8 +46,12 @@ def load_model_from_config(config, ckpt, verbose=False):

 def load_img(path):
-    image = np.array(Image.open(path).convert("RGB"))
-    image = image.astype(np.float32) / 255.0
+    image = Image.open(path).convert("RGB")
+    w, h = image.size
+    print(f"loaded input image of size ({w}, {h}) from {path}")
+    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
     image = image[None].transpose(0, 3, 1, 2)
     image = torch.from_numpy(image)
     return 2.*image - 1.
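(The rounding added above floors each dimension to a multiple of 32 so the latent shapes divide evenly; a minimal standalone sketch of the same arithmetic, with illustrative sizes not taken from the commit:)

# Floor each dimension to a multiple of 32, as load_img now does.
w, h = 513, 769
w, h = map(lambda x: x - x % 32, (w, h))
print(w, h)  # -> 512 768
# The final `2.*image - 1.` then maps pixel values from [0, 1] to [-1, 1].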
@@ -119,20 +126,6 @@ def main():
         help="sample this often",
     )
-    parser.add_argument(
-        "--H",
-        type=int,
-        default=256,
-        help="image height, in pixel space",
-    )
-    parser.add_argument(
-        "--W",
-        type=int,
-        default=256,
-        help="image width, in pixel space",
-    )
     parser.add_argument(
         "--C",
         type=int,
@@ -149,7 +142,7 @@ def main():
     parser.add_argument(
         "--n_samples",
         type=int,
-        default=8,
+        default=2,
         help="how many samples to produce for each given prompt. A.k.a batch size",
     )
@@ -170,7 +163,7 @@ def main():
     parser.add_argument(
         "--strength",
         type=float,
-        default=0.3,
+        default=0.75,
         help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
     )
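(The new default feeds the encode depth computed further down in this file, t_enc = int(opt.strength * opt.ddim_steps); a quick worked example, with a step count chosen for illustration:)

# How strength maps to encode steps, per the t_enc line later in this diff.
ddim_steps, strength = 50, 0.75
t_enc = int(strength * ddim_steps)
print(t_enc)  # -> 37: the init latent is noised for 37 of 50 steps, then decoded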
@@ -197,6 +190,14 @@ def main():
         default=42,
         help="the seed (for reproducible sampling)",
     )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        help="evaluate at this precision",
+        choices=["full", "autocast"],
+        default="autocast"
+    )
     opt = parser.parse_args()

     seed_everything(opt.seed)
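(Why nullcontext can stand in for autocast in the hunk below: both are callables that take one positional argument and return a context manager, and contextlib.nullcontext simply ignores its enter_result argument. A minimal sketch of the pattern in isolation:)

from contextlib import nullcontext
from torch import autocast

precision = "autocast"  # or "full"
precision_scope = autocast if precision == "autocast" else nullcontext
with precision_scope("cuda"):
    pass  # ops run under fp16 autocast, or untouched fp32 when "full"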
@@ -244,51 +245,53 @@ def main():
     t_enc = int(opt.strength * opt.ddim_steps)
     print(f"target t_enc is {t_enc} steps")

+    precision_scope = autocast if opt.precision == "autocast" else nullcontext
     with torch.no_grad():
-        with model.ema_scope():
-            tic = time.time()
-            all_samples = list()
-            for n in trange(opt.n_iter, desc="Sampling"):
-                for prompts in tqdm(data, desc="data"):
-                    uc = None
-                    if opt.scale != 1.0:
-                        uc = model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-                    c = model.get_learned_conditioning(prompts)
-
-                    # encode (scaled latent)
-                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
-                    # decode it
-                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
-                                             unconditional_conditioning=uc,)
-
-                    x_samples = model.decode_first_stage(samples)
-                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-                    if not opt.skip_save:
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            Image.fromarray(x_sample.astype(np.uint8)).save(
-                                os.path.join(sample_path, f"{base_count:05}.png"))
-                            base_count += 1
-                    all_samples.append(x_samples)
-
-            if not opt.skip_grid:
-                # additionally, save as grid
-                grid = torch.stack(all_samples, 0)
-                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                grid = make_grid(grid, nrow=n_rows)
-
-                # to image
-                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                grid_count += 1
-
-            toc = time.time()
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(opt.n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
+                        uc = None
+                        if opt.scale != 1.0:
+                            uc = model.get_learned_conditioning(batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = model.get_learned_conditioning(prompts)
+
+                        # encode (scaled latent)
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+                        # decode it
+                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
+                                                 unconditional_conditioning=uc,)
+
+                        x_samples = model.decode_first_stage(samples)
+                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        if not opt.skip_save:
+                            for x_sample in x_samples:
+                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                Image.fromarray(x_sample.astype(np.uint8)).save(
+                                    os.path.join(sample_path, f"{base_count:05}.png"))
+                                base_count += 1
+                        all_samples.append(x_samples)
+
+                if not opt.skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
+
+                    # to image
+                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
+
+                toc = time.time()

     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
+          f"Sampling took {toc - tic}s, i.e., produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")

View File

@@ -3,7 +3,7 @@ import torch
 import fire


-def prune_it(p):
+def prune_it(p, keep_only_ema=False):
     print(f"prunin' in path: {p}")
     size_initial = os.path.getsize(p)
     nsd = dict()
@@ -16,12 +16,30 @@ def prune_it(p):
         print(f"removing optimizer states for path {p}")
     if "global_step" in sd:
         print(f"This is global step {sd['global_step']}.")
-    fn = f"{os.path.splitext(p)[0]}-pruned.ckpt"
+    if keep_only_ema:
+        sd = nsd["state_dict"].copy()
+        # infer ema keys
+        ema_keys = {k: "model_ema." + k[6:].replace(".", "") for k in sd.keys() if k.startswith("model.")}
+        new_sd = dict()
+
+        for k in sd:
+            if k in ema_keys:
+                new_sd[k] = sd[ema_keys[k]]
+            elif not k.startswith("model_ema.") or k in ["model_ema.num_updates", "model_ema.decay"]:
+                new_sd[k] = sd[k]
+
+        assert len(new_sd) == len(sd) - len(ema_keys)
+        nsd["state_dict"] = new_sd
+
+    fn = f"{os.path.splitext(p)[0]}-pruned.ckpt" if not keep_only_ema else f"{os.path.splitext(p)[0]}-ema-pruned.ckpt"
     print(f"saving pruned checkpoint at: {fn}")
     torch.save(nsd, fn)
     newsize = os.path.getsize(fn)
-    print(f"New ckpt size: {newsize*1e-9:.2f} GB. "
-          f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states")
+    MSG = f"New ckpt size: {newsize*1e-9:.2f} GB. " + \
+          f"Saved {(size_initial - newsize)*1e-9:.2f} GB by removing optimizer states"
+    if keep_only_ema:
+        MSG += " and non-EMA weights"
+    print(MSG)


 if __name__ == "__main__":
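(The ema_keys mapping above works because EMA weights are registered as module buffers, and buffer names cannot contain dots, so the original parameter names are stored with dots stripped. A small sketch of the string transform, using an illustrative key name:)

# k[6:] drops the leading "model." prefix (6 characters); the remaining
# dots are removed to match the flattened EMA buffer naming.
k = "model.diffusion_model.out.2.weight"
ema_key = "model_ema." + k[6:].replace(".", "")
print(ema_key)  # -> model_ema.diffusion_modelout2weight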

View File

@@ -9,6 +9,8 @@ from einops import rearrange
 from torchvision.utils import make_grid
 import time
 from pytorch_lightning import seed_everything
+from torch import autocast
+from contextlib import contextmanager, nullcontext

 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -178,6 +180,13 @@ def main():
         default=42,
         help="the seed (for reproducible sampling)",
     )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        help="evaluate at this precision",
+        choices=["full", "autocast"],
+        default="autocast"
+    )
     opt = parser.parse_args()

     seed_everything(opt.seed)
@@ -217,53 +226,55 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

+    precision_scope = autocast if opt.precision=="autocast" else nullcontext
     with torch.no_grad():
-        with model.ema_scope():
-            tic = time.time()
-            all_samples = list()
-            for n in trange(opt.n_iter, desc="Sampling"):
-                for prompts in tqdm(data, desc="data"):
-                    uc = None
-                    if opt.scale != 1.0:
-                        uc = model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-                    c = model.get_learned_conditioning(prompts)
-                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
-                                                     conditioning=c,
-                                                     batch_size=opt.n_samples,
-                                                     shape=shape,
-                                                     verbose=False,
-                                                     unconditional_guidance_scale=opt.scale,
-                                                     unconditional_conditioning=uc,
-                                                     eta=opt.ddim_eta,
-                                                     dynamic_threshold=opt.dyn,
-                                                     x_T=start_code)
-
-                    x_samples_ddim = model.decode_first_stage(samples_ddim)
-                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
-                    if not opt.skip_save:
-                        for x_sample in x_samples_ddim:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            Image.fromarray(x_sample.astype(np.uint8)).save(
-                                os.path.join(sample_path, f"{base_count:05}.png"))
-                            base_count += 1
-                    all_samples.append(x_samples_ddim)
-
-            if not opt.skip_grid:
-                # additionally, save as grid
-                grid = torch.stack(all_samples, 0)
-                grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-                grid = make_grid(grid, nrow=n_rows)
-
-                # to image
-                grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-                Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
-                grid_count += 1
-
-            toc = time.time()
+        with precision_scope("cuda"):
+            with model.ema_scope():
+                tic = time.time()
+                all_samples = list()
+                for n in trange(opt.n_iter, desc="Sampling"):
+                    for prompts in tqdm(data, desc="data"):
+                        uc = None
+                        if opt.scale != 1.0:
+                            uc = model.get_learned_conditioning(batch_size * [""])
+                        if isinstance(prompts, tuple):
+                            prompts = list(prompts)
+                        c = model.get_learned_conditioning(prompts)
+                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                         conditioning=c,
+                                                         batch_size=opt.n_samples,
+                                                         shape=shape,
+                                                         verbose=False,
+                                                         unconditional_guidance_scale=opt.scale,
+                                                         unconditional_conditioning=uc,
+                                                         eta=opt.ddim_eta,
+                                                         dynamic_threshold=opt.dyn,
+                                                         x_T=start_code)
+
+                        x_samples_ddim = model.decode_first_stage(samples_ddim)
+                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+                        if not opt.skip_save:
+                            for x_sample in x_samples_ddim:
+                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                Image.fromarray(x_sample.astype(np.uint8)).save(
+                                    os.path.join(sample_path, f"{base_count:05}.png"))
+                                base_count += 1
+                        all_samples.append(x_samples_ddim)
+
+                if not opt.skip_grid:
+                    # additionally, save as grid
+                    grid = torch.stack(all_samples, 0)
+                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+                    grid = make_grid(grid, nrow=n_rows)
+
+                    # to image
+                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+                    grid_count += 1
+
+                toc = time.time()

     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")