From af4db5dafc622f14fc8247e457c622f0efdd1d02 Mon Sep 17 00:00:00 2001
From: rromb
Date: Fri, 15 Jul 2022 13:40:46 +0200
Subject: [PATCH 1/2] EMAWithWings
 (https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298)

---
 ldm/util.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 112 insertions(+), 1 deletion(-)

diff --git a/ldm/util.py b/ldm/util.py
index 51839cb..8c09ca1 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -1,6 +1,7 @@
 import importlib
 
 import torch
+from torch import optim
 import numpy as np
 
 from inspect import isfunction
@@ -83,4 +84,114 @@ def get_obj_from_str(string, reload=False):
     if reload:
         module_imp = importlib.import_module(module)
         importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
\ No newline at end of file
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,  # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                                    grads,
+                                    exp_avgs,
+                                    exp_avg_sqs,
+                                    max_exp_avg_sqs,
+                                    state_steps,
+                                    amsgrad=amsgrad,
+                                    beta1=beta1,
+                                    beta2=beta2,
+                                    lr=group['lr'],
+                                    weight_decay=group['weight_decay'],
+                                    eps=group['eps'],
+                                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
\ No newline at end of file
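
Usage note (illustrative, not part of the patch): AdamWwithEMAandWings keeps a float32 EMA copy of every parameter in the optimizer state under 'param_exp_avg' and refreshes it after each AdamW update with decay min(ema_decay, 1 - step ** -ema_power); with the default ema_power=1. the decay ramps up as 1 - 1/step before saturating at ema_decay. A minimal sketch of how the class could be driven follows; the toy nn.Linear model, dummy loss, and copy into ema_model are assumptions, only the class name, its constructor arguments, and the 'param_exp_avg' state entry come from the patch above.

import torch
from torch import nn

from ldm.util import AdamWwithEMAandWings  # added by this patch

model = nn.Linear(16, 16)  # stand-in for the actual diffusion model
opt = AdamWwithEMAandWings(model.parameters(), lr=1.e-4, ema_decay=0.9999)

for _ in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy objective
    opt.zero_grad()
    loss.backward()
    opt.step()  # AdamW update, then EMA update of state['param_exp_avg']

# The EMA weights live in the per-parameter optimizer state; copy them into a
# separate module, e.g. for evaluation or checkpointing.
ema_model = nn.Linear(16, 16)
with torch.no_grad():
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.copy_(opt.state[p]['param_exp_avg'].to(dtype=p_ema.dtype))
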
From 37e59ee4878f81bcebfadfb5991405a67a0919e4 Mon Sep 17 00:00:00 2001
From: rromb
Date: Fri, 15 Jul 2022 13:50:15 +0200
Subject: [PATCH 2/2] optionally fix start code for sampling

---
 scripts/txt2img.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 6e98f83..37797ac 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -83,6 +83,11 @@
         action='store_true',
         help="use plms sampling",
     )
+    parser.add_argument(
+        "--fixed_code",
+        action='store_true',
+        help="if enabled, uses the same starting code across all samples ",
+    )
 
     parser.add_argument(
         "--ddim_eta",
@@ -155,7 +160,6 @@
         type=str,
         help="if specified, load prompts from this file",
     )
-
     parser.add_argument(
         "--config",
         type=str,
@@ -209,6 +213,10 @@
     base_count = len(os.listdir(sample_path))
     grid_count = len(os.listdir(outpath)) - 1
 
+    start_code = None
+    if opt.fixed_code:
+        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+
     with torch.no_grad():
         with model.ema_scope():
             tic = time.time()
@@ -221,7 +229,7 @@
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
                     c = model.get_learned_conditioning(prompts)
-                    shape = [opt.C, opt.H//opt.f, opt.W//opt.f]
+                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                     samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                      conditioning=c,
                                                      batch_size=opt.n_samples,
@@ -230,15 +238,17 @@
                                                      unconditional_guidance_scale=opt.scale,
                                                      unconditional_conditioning=uc,
                                                      eta=opt.ddim_eta,
-                                                     dynamic_threshold=opt.dyn)
+                                                     dynamic_threshold=opt.dyn,
+                                                     x_T=start_code)
 
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
-                    x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
                     if not opt.skip_save:
                         for x_sample in x_samples_ddim:
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:05}.png"))
+                            Image.fromarray(x_sample.astype(np.uint8)).save(
+                                os.path.join(sample_path, f"{base_count:05}.png"))
                             base_count += 1
                     all_samples.append(x_samples_ddim)
 
@@ -256,7 +266,7 @@
     toc = time.time()
 
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
-          f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
+          f"Sampling took {toc - tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")
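
Usage note (illustrative, not part of the patch): with --fixed_code the script draws a single noise tensor of shape [n_samples, C, H // f, W // f] before the sampling loop and passes it as x_T to every sampler.sample(...) call, so all batches and prompts in one run start from identical latents and the outputs differ only through the conditioning; without the flag x_T stays None and fresh noise is drawn per batch. An invocation could look like: python scripts/txt2img.py --prompt "a corgi" --plms --fixed_code (assuming the script's existing --prompt/--plms options). A self-contained sketch of the start-code idea; the shape values are illustrative, not read from the parser defaults.

import torch

# Latent "start code" shape used by the script: [n_samples, C, H // f, W // f].
n_samples, C, H, W, f = 4, 4, 256, 256, 8

# What --fixed_code does: draw the starting latents once, up front ...
start_code = torch.randn([n_samples, C, H // f, W // f])
# ... and reuse this same tensor as x_T for every sampling call in the run.

# The patch does not seed the RNG, so the fixed code still differs from run to
# run; seeding it (shown purely as an illustration) would also make the
# starting latents repeatable across runs:
generator = torch.Generator().manual_seed(42)
repeatable_code = torch.randn([n_samples, C, H // f, W // f], generator=generator)
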