From af4db5dafc622f14fc8247e457c622f0efdd1d02 Mon Sep 17 00:00:00 2001 From: rromb Date: Fri, 15 Jul 2022 13:40:46 +0200 Subject: [PATCH] EMAWithWings (https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298) --- ldm/util.py | 113 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/ldm/util.py b/ldm/util.py index 51839cb..8c09ca1 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -1,6 +1,7 @@ import importlib import torch +from torch import optim import numpy as np from inspect import isfunction @@ -83,4 +84,114 @@ def get_obj_from_str(string, reload=False): if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) \ No newline at end of file + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file