initial code to try and resume larger models from smaller models
This commit is contained in:
parent a66b27b149
commit 250f89c1a3
1 changed file with 48 additions and 0 deletions
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import LambdaLR
 from einops import rearrange, repeat
 from contextlib import contextmanager, nullcontext
 from functools import partial
+import itertools
 from tqdm import tqdm
 from torchvision.utils import make_grid
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -72,6 +73,7 @@ class DDPM(pl.LightningModule):
                  use_positional_encodings=False,
                  learn_logvar=False,
                  logvar_init=0.,
+                 make_it_fit=False,
                  ):
         super().__init__()
         assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
@@ -101,6 +103,7 @@ class DDPM(pl.LightningModule):
 
         if monitor is not None:
             self.monitor = monitor
+        self.make_it_fit = make_it_fit
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
 
@@ -184,6 +187,7 @@ class DDPM(pl.LightningModule):
             if context is not None:
                 print(f"{context}: Restored training weights")
 
+    @torch.no_grad()
     def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
         sd = torch.load(path, map_location="cpu")
         if "state_dict" in list(sd.keys()):
@@ -194,6 +198,50 @@ class DDPM(pl.LightningModule):
                 if k.startswith(ik):
                     print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if not name in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape)==len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
         missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
             sd, strict=False)
         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
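
For readers who want to see the weight-fitting step in isolation, below is a minimal, hypothetical sketch of the same tiling-and-rescaling idea for a single 2-D weight tensor. The helper name fit_old_to_new is illustrative and does not exist in the repository; the logic mirrors the added block above, including its initialisation of the reuse counts at one, so the column rescaling is approximate rather than an exact mean.

import torch

def fit_old_to_new(old_param: torch.Tensor, new_param: torch.Tensor) -> torch.Tensor:
    # Tile a smaller weight into a larger one along the first two axes,
    # wrapping indices modulo the old shape; trailing axes must already match.
    old_shape, new_shape = old_param.shape, new_param.shape
    assert len(old_shape) == len(new_shape) >= 2
    assert new_shape[2:] == old_shape[2:]

    out = new_param.clone()
    for i in range(new_shape[0]):
        for j in range(new_shape[1]):
            out[i, j] = old_param[i % old_shape[0], j % old_shape[1]]

    # Count how often each old input column is reused (counts start at one,
    # matching the commit), then divide each new column by its count so the
    # layer's output magnitude does not blow up as the input dim grows.
    n_used_old = torch.ones(old_shape[1])
    for j in range(new_shape[1]):
        n_used_old[j % old_shape[1]] += 1
    n_used_new = torch.zeros(new_shape[1])
    for j in range(new_shape[1]):
        n_used_new[j] = n_used_old[j % old_shape[1]]

    # Broadcast the counts over the output dim (and any trailing axes).
    n_used_new = n_used_new[None, :]
    while n_used_new.dim() < len(new_shape):
        n_used_new = n_used_new.unsqueeze(-1)
    return out / n_used_new

# Example: grow a 4x2 linear weight to 4x4; each old input column is reused
# twice, so copied values end up divided by 3 under this counting scheme.
small = torch.randn(4, 2)
large = torch.empty(4, 4)
print(fit_old_to_new(small, large).shape)  # torch.Size([4, 4])

In the commit itself this logic runs inside init_from_ckpt whenever make_it_fit=True is passed to DDPM, so a larger model can be initialised from a smaller checkpoint and then fine-tuned.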