initial code to try to resume larger models from smaller models

Patrick Esser 2022-06-11 18:35:49 -04:00
parent a66b27b149
commit 250f89c1a3


@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from einops import rearrange, repeat
 from contextlib import contextmanager, nullcontext
 from functools import partial
+import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -72,6 +73,7 @@ class DDPM(pl.LightningModule):
                  use_positional_encodings=False,
                  learn_logvar=False,
                  logvar_init=0.,
+                 make_it_fit=False,
                  ):
         super().__init__()
         assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
@@ -101,6 +103,7 @@ class DDPM(pl.LightningModule):
         if monitor is not None:
             self.monitor = monitor
+        self.make_it_fit = make_it_fit
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
@@ -184,6 +187,7 @@ class DDPM(pl.LightningModule):
             if context is not None:
                 print(f"{context}: Restored training weights")

+    @torch.no_grad()
     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
@@ -194,6 +198,50 @@ class DDPM(pl.LightningModule):
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+                    sd[name] = new_param
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")