initial code to try and resume larger models from smaller models
parent a66b27b149
commit 250f89c1a3
1 changed file with 48 additions and 0 deletions
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from einops import rearrange, repeat
 from contextlib import contextmanager, nullcontext
 from functools import partial
+import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -72,6 +73,7 @@ class DDPM(pl.LightningModule):
                  use_positional_encodings=False,
                  learn_logvar=False,
                  logvar_init=0.,
+                 make_it_fit=False,
                  ):
         super().__init__()
         assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
@@ -101,6 +103,7 @@ class DDPM(pl.LightningModule):
 
         if monitor is not None:
             self.monitor = monitor
+        self.make_it_fit = make_it_fit
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
 
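The new kwarg only takes effect through `init_from_ckpt`, so the intended flow is to build the larger model as usual and point `ckpt_path` at the smaller model's checkpoint. A minimal usage sketch (not part of the commit): `model` is assumed to be an already-constructed `DDPM` subclass built from the larger architecture config, and the checkpoint path is a placeholder.

    # Sketch only; in practice make_it_fit=True and ckpt_path are passed to __init__
    # via the training config, which triggers the same call during construction.
    model.make_it_fit = True
    model.init_from_ckpt("logs/small_run/last.ckpt")  # placeholder path to the smaller model's weights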
@@ -184,6 +187,7 @@ class DDPM(pl.LightningModule):
         if context is not None:
             print(f"{context}: Restored training weights")
 
+    @torch.no_grad()
     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
@@ -194,6 +198,50 @@ class DDPM(pl.LightningModule):
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")