From 250f89c1a3ab022a3a875fc708f46d2e66f38426 Mon Sep 17 00:00:00 2001
From: Patrick Esser
Date: Sat, 11 Jun 2022 18:35:49 -0400
Subject: [PATCH] initial code to try and resume larger models from smaller
 models

---
 ldm/models/diffusion/ddpm.py | 48 ++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 85c2feb..9ca6ff6 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from einops import rearrange, repeat
 from contextlib import contextmanager, nullcontext
 from functools import partial
+import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -72,6 +73,7 @@ class DDPM(pl.LightningModule):
                  use_positional_encodings=False,
                  learn_logvar=False,
                  logvar_init=0.,
+                 make_it_fit=False,
                  ):
         super().__init__()
         assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
@@ -101,6 +103,7 @@ class DDPM(pl.LightningModule):
 
         if monitor is not None:
             self.monitor = monitor
+        self.make_it_fit = make_it_fit
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
 
@@ -184,6 +187,7 @@ class DDPM(pl.LightningModule):
                 if context is not None:
                     print(f"{context}: Restored training weights")
 
+    @torch.no_grad()
     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
@@ -194,6 +198,50 @@ class DDPM(pl.LightningModule):
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape)==len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
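
The last hunk grows each mismatched checkpoint tensor by tiling the old weights modulo the old size along the first two axes (treated as output and input dims), then rescales along the input axis so duplicated input channels do not inflate activations. The snippet below is a minimal standalone sketch of that tiling-and-rescaling idea for a single tensor, not the patch itself: the function name fit_old_to_new and the toy shapes are made up for illustration, and its reuse counts start at zero rather than one.

import torch


def fit_old_to_new(old, new_shape):
    # Tile `old` into `new_shape` along the first two axes (axis 0 is taken as
    # the output dim, axis 1 as the input dim); trailing axes must already match.
    assert len(old.shape) == len(new_shape) >= 2
    assert tuple(old.shape[2:]) == tuple(new_shape[2:])
    out_old, in_old = old.shape[0], old.shape[1]

    new = torch.empty(new_shape, dtype=old.dtype)
    for i in range(new_shape[0]):
        for j in range(new_shape[1]):
            new[i, j] = old[i % out_old, j % in_old]

    # Count how many new input columns were copied from each old column and
    # divide by that count, so a duplicated input channel contributes roughly
    # what the single original channel did.
    n_used_old = torch.zeros(in_old)
    for j in range(new_shape[1]):
        n_used_old[j % in_old] += 1
    scale = n_used_old[torch.arange(new_shape[1]) % in_old]
    scale = scale.reshape(1, -1, *([1] * (len(new_shape) - 2)))
    return new / scale


old_w = torch.randn(4, 3, 1, 1)              # e.g. a small 1x1 conv weight
new_w = fit_old_to_new(old_w, (8, 6, 1, 1))  # doubled output and input channels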