Finetuning (#11)
* simple datasets
* add conversion script
* finish fine tune example
* update readme
* update readme
This commit is contained in:
parent
704f564366
commit
f1293f9795
11 changed files with 942 additions and 15 deletions
1  .gitignore  (vendored)

@@ -1,6 +1,5 @@
logs/
dump/
im-examples/
outputs/
flagged/
*.egg-info
28  README.md

@@ -1,13 +1,33 @@
# Experiments with Stable Diffusion

This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.

Currently it adds:

- [Fine tuning](#fine-tuning)
- [Image variations](#image-variations)
- [Conversion to Huggingface Diffusers](scripts/convert_sd_to_diffusers.py)

## Fine tuning

Makes it easy to fine tune Stable Diffusion on your own dataset. For example, generating new Pokemon from text:

![](assets/pokemontage.jpg)

> Girl with a pearl earring, Cute Obama creature, Donald Trump, Boris Johnson, Totoro, Hello Kitty

For a step-by-step guide see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).

## Image variations

[![](assets/img-vars.jpg)](https://twitter.com/Buntworthy/status/1561703483316781057)
![](assets/im-vars-thin.jpg)

Try it out in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
[![Open Demo](https://img.shields.io/badge/%CE%BB-Open%20Demo-blueviolet)](https://47725.gradio.app/)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
[![Open in Spaces](https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-orange)]()

_TODO describe in more detail_
For more details on the Image Variation model see the [model card](https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned).

- Get access to a Linux machine with a decent NVIDIA GPU (e.g. on [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
- Clone this repo
BIN  assets/pokemontage.jpg  Normal file

Binary file not shown. After: Size 148 KiB.
133  configs/stable-diffusion/pokemon.yaml  Normal file

@@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset
      params:
        name: lambdalabs/pokemon-blip-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "A pokemon with green eyes, large wings, and a hat"
        - "A cute bunny rabbit"
        - "Yoda"
        - "An epic landscape photo of a mountain"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
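The `data` block above wires fine tuning to the new `ldm.data.simple.hf_dataset` helper added by this commit (see ldm/data/simple.py below). As a rough orientation-only sketch of what `main.DataModuleFromConfig` effectively does with that block, the train split can be built directly; the transform dicts mirror the YAML and are resolved through `instantiate_from_config`:

from ldm.data.simple import hf_dataset

# Same transforms as the YAML `image_transforms` list, expressed as plain config dicts.
train_ds = hf_dataset(
    name="lambdalabs/pokemon-blip-captions",
    image_transforms=[
        {"target": "torchvision.transforms.Resize",
         "params": {"size": 512, "interpolation": 3}},
        {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
        {"target": "torchvision.transforms.RandomHorizontalFlip"},
    ],
)

sample = train_ds[0]
print(sample["txt"])           # caption text
print(sample["image"].shape)   # torch.Size([512, 512, 3]), values scaled to [-1, 1]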
101  ldm/data/simple.py  Normal file

@@ -0,0 +1,101 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset

class FolderData(Dataset):
    def __init__(self, root_dir, caption_file, image_transforms, ext="jpg") -> None:
        self.root_dir = Path(root_dir)
        with open(caption_file, "rt") as f:
            captions = json.load(f)
        self.captions = captions

        self.paths = list(self.root_dir.rglob(f"*.{ext}"))
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

        # assert all(['full/' + str(x.name) in self.captions for x in self.paths])

    def __len__(self):
        return len(self.captions.keys())

    def __getitem__(self, index):
        chosen = list(self.captions.keys())[index]
        im = Image.open(self.root_dir/chosen)
        im = self.process_im(im)
        caption = self.captions[chosen]
        if caption is None:
            caption = "old book illustration"
        return {"jpg": im, "txt": caption}

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

def hf_dataset(
    name,
    image_transforms=[],
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied
    """
    ds = load_dataset(name, split=split)
    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    tform = transforms.Compose(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds

class TextOnly(Dataset):
    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images"""
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            repeated = [n_gpus*[x] for x in self.captions]
            self.captions = []
            [self.captions.extend(x) for x in repeated]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]
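Beyond the Hugging Face path, this file also covers local folders (`FolderData`) and caption-only validation (`TextOnly`). A minimal usage sketch, assuming a hypothetical local image folder and captions file that do not ship with the repo:

from ldm.data.simple import FolderData, TextOnly

# Hypothetical inputs: a folder of .jpg files plus a JSON file mapping
# relative image paths to caption strings.
folder_ds = FolderData(
    root_dir="data/my_images",
    caption_file="data/my_captions.json",
    image_transforms=[{"target": "torchvision.transforms.Resize", "params": {"size": 512}}],
)
item = folder_ds[0]   # {"jpg": (H, W, 3) tensor in [-1, 1], "txt": caption string}

# TextOnly feeds fixed captions (with dummy images) to the validation logger;
# n_gpus=2 repeats each caption so every GPU sees all of them.
val_ds = TextOnly(
    captions=["A pokemon with green eyes, large wings, and a hat", "Yoda"],
    output_size=512,
    n_gpus=2,
)
print(len(val_ds))    # 4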
@@ -159,7 +159,8 @@ class DDIMSampler(object):
                                       unconditional_conditioning=unconditional_conditioning,
                                       dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
-            if callback: callback(i)
+            if callback:
+                img = callback(i, img, pred_x0)
             if img_callback: img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
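Note the changed callback contract in this hunk: the sampler used to call `callback(i)` purely for its side effects, but now assigns `img = callback(i, img, pred_x0)`, so any callback must return the (possibly modified) latent. A minimal compatible callback, shown only as a sketch:

import torch

def latent_logging_callback(step: int, img: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
    # Inspect (or edit) the running latent; it must be returned because the
    # sampler now does `img = callback(i, img, pred_x0)`.
    if step % 10 == 0:
        print(f"DDIM step {step}: latent std {img.std().item():.3f}")
    return img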
@@ -1343,9 +1343,8 @@ class LatentDiffusion(DDPM):
             log["samples_x0_quantized"] = x_samples
 
         if unconditional_guidance_scale > 1.0:
-            # uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
-            # FIXME
-            uc = torch.zeros_like(c)
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # uc = torch.zeros_like(c)
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                  ddim_steps=ddim_steps, eta=ddim_eta,
@@ -1396,6 +1395,13 @@ class LatentDiffusion(DDPM):
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.model.parameters())
+        # FIXME JP
+        # params = []
+        # from ldm.modules.attention import CrossAttention
+        # for n, m in self.model.named_modules():
+        #     if isinstance(m, CrossAttention) and n.endswith('attn2'):
+        #         params.extend(m.parameters())
+        # END FIXME JP
         if self.cond_stage_trainable:
             print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
             params = params + list(self.cond_stage_model.parameters())
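The commented-out `FIXME JP` block sketches an alternative fine-tuning mode that would optimize only the cross-attention (`attn2`) projections instead of the full UNet; it stays disabled in this commit. For reference, a stand-alone version of that selection logic (a sketch, not active code):

from ldm.modules.attention import CrossAttention

def cross_attention_params(diffusion_model):
    """Collect parameters of the text cross-attention blocks only."""
    params = []
    for name, module in diffusion_model.named_modules():
        if isinstance(module, CrossAttention) and name.endswith("attn2"):
            params.extend(module.parameters())
    return params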
@@ -172,6 +172,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text):
         return self(text)
 
+class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
+        super().__init__()
+        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
+        self.projection = torch.nn.Linear(768, 768)
+
+    def forward(self, text):
+        z = self.embedder(text)
+        return self.projection(z)
+
+    def encode(self, text):
+        return self(text)
+
 class FrozenCLIPImageEmbedder(AbstractEncoder):
     """
     Uses the CLIP image encoder.
@@ -192,6 +205,14 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
         self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
 
+        # I didn't call this originally, but seems like it was frozen anyway
+        self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
     def preprocess(self, x):
         # Expects inputs in the range -1, 1
         x = kornia.geometry.resize(x, (224, 224),
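The new `ProjectedFrozenCLIPEmbedder` keeps the CLIP text encoder frozen but adds a trainable 768-to-768 linear projection on top of its token embeddings. A quick exercise of the class (a sketch; the `device="cpu"` override and the printed shape are assumptions based on CLIP ViT-L/14's 77-token, 768-dim text output):

import torch
from ldm.modules.encoders.modules import ProjectedFrozenCLIPEmbedder

embedder = ProjectedFrozenCLIPEmbedder(device="cpu")  # default is "cuda"
with torch.no_grad():
    z = embedder.encode(["a pokemon with green eyes"])
print(z.shape)  # expected: torch.Size([1, 77, 768]) -- one projected vector per token position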
20  main.py

@@ -172,11 +172,15 @@ def worker_init_fn(_):
 class DataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                  wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
-                 shuffle_val_dataloader=False):
+                 shuffle_val_dataloader=False, num_val_workers=None):
         super().__init__()
         self.batch_size = batch_size
         self.dataset_configs = dict()
         self.num_workers = num_workers if num_workers is not None else batch_size * 2
+        if num_val_workers is None:
+            self.num_val_workers = self.num_workers
+        else:
+            self.num_val_workers = num_val_workers
         self.use_worker_init_fn = use_worker_init_fn
         if train is not None:
             self.dataset_configs["train"] = train
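The new `num_val_workers` argument decouples the validation loader's worker count from the training one (the pokemon config above sets it to 0 to sidestep a val dataloader issue). A hedged construction sketch, with the train/validation configs taken from the YAML and written as plain dicts:

from main import DataModuleFromConfig

data = DataModuleFromConfig(
    batch_size=4,
    num_workers=4,
    num_val_workers=0,   # val_dataloader() now loads data in the main process
    train={"target": "ldm.data.simple.hf_dataset",
           "params": {"name": "lambdalabs/pokemon-blip-captions"}},
    validation={"target": "ldm.data.simple.TextOnly",
                "params": {"captions": ["Yoda"], "output_size": 512, "n_gpus": 2}},
)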
@@ -221,7 +225,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
             init_fn = None
         return DataLoader(self.datasets["validation"],
                           batch_size=self.batch_size,
-                          num_workers=self.num_workers,
+                          num_workers=self.num_val_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)
 
@@ -304,7 +308,7 @@ class SetupCallback(Callback):
 class ImageLogger(Callback):
     def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                  rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
-                 log_images_kwargs=None):
+                 log_images_kwargs=None, log_all_val=False):
         super().__init__()
         self.rescale = rescale
         self.batch_freq = batch_frequency

@@ -320,6 +324,7 @@ class ImageLogger(Callback):
         self.log_on_batch_idx = log_on_batch_idx
         self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
         self.log_first_step = log_first_step
+        self.log_all_val = log_all_val
 
     @rank_zero_only
     def _testtube(self, pl_module, images, batch_idx, split):
@@ -354,10 +359,13 @@ class ImageLogger(Callback):
 
     def log_img(self, pl_module, batch, batch_idx, split="train"):
         check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
-        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+        if self.log_all_val and split == "val":
+            should_log = True
+        else:
+            should_log = self.check_frequency(check_idx)
+        if (should_log and  # batch_idx % self.batch_freq == 0
                 hasattr(pl_module, "log_images") and
                 callable(pl_module.log_images) and
                 batch_idx > 5 and
                 self.max_images > 0):
             logger = type(pl_module.logger)
 
@@ -687,7 +695,7 @@ if __name__ == "__main__":
                 }
             },
         }
-        default_logger_cfg = default_logger_cfgs["wandb"]
+        default_logger_cfg = default_logger_cfgs["testtube"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
         else:
@@ -15,6 +15,8 @@ webdataset==0.2.5
 torchmetrics==0.6.0
 fire==0.4.0
 gradio==3.1.4
+diffusers==0.3.0
+datasets[vision]==2.4.0
 -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
 -e git+https://github.com/openai/CLIP.git@main#egg=clip
 -e .
636  scripts/convert_sd_to_diffusers.py  Normal file

@@ -0,0 +1,636 @@
# coding=utf-8
# Modified by Justin Pinkney
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import torch

try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError("OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`.")

from transformers import BertTokenizerFast, CLIPTokenizer, CLIPTextModel
from transformers import CLIPFeatureExtractor
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel, LDMBertConfig


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return '.'.join(path.split('.')[n_shave_prefix_segments:])
    else:
        return '.'.join(path.split('.')[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace('in_layers.0', 'norm1')
        new_item = new_item.replace('in_layers.2', 'conv1')

        new_item = new_item.replace('out_layers.0', 'norm2')
        new_item = new_item.replace('out_layers.3', 'conv2')

        new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
        new_item = new_item.replace('skip_connection', 'conv_shortcut')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace('norm.weight', 'group_norm.weight')
        new_item = new_item.replace('norm.bias', 'group_norm.bias')

        new_item = new_item.replace('q.weight', 'query.weight')
        new_item = new_item.replace('q.bias', 'query.bias')

        new_item = new_item.replace('k.weight', 'key.weight')
        new_item = new_item.replace('k.bias', 'key.bias')

        new_item = new_item.replace('v.weight', 'value.weight')
        new_item = new_item.replace('v.bias', 'value.bias')

        new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map['query']] = query.reshape(target_shape)
            checkpoint[path_map['key']] = key.reshape(target_shape)
            checkpoint[path_map['value']] = value.reshape(target_shape)

    for path in paths:
        new_path = path['new']

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
        new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
        new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement['old'], replacement['new'])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path['old']]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    unet_params = original_config.model.params.unet_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=64,
        in_channels=unet_params.in_channels,
        out_channels=unet_params.out_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_res_blocks,
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )

    return config


def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    latent_channles = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=512,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config


def create_diffusers_schedular(original_config):
    schedular = PNDMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
    )
    return schedular


def create_ldm_bert_config(original_config):
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict['time_embed.0.weight']
    new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict['time_embed.0.bias']
    new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict['time_embed.2.weight']
    new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict['time_embed.2.bias']

    new_checkpoint['conv_in.weight'] = unet_state_dict['input_blocks.0.0.weight']
    new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']

    new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
    new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
    new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
    new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'input_blocks' in layer})
    input_blocks = {layer_id: [key for key in unet_state_dict if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'middle_block' in layer})
    middle_blocks = {layer_id: [key for key in unet_state_dict if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'output_blocks' in layer})
    output_blocks = {layer_id: [key for key in unet_state_dict if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config['layers_per_block'] + 1)
        layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)

        resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key and f'input_blocks.{i}.0.op' not in key]
        attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]

        if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
            new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.weight'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
            new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.bias'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')

        paths = renew_resnet_paths(resnets)
        meta_path = {'old': f'input_blocks.{i}.0', 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}'}
        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {'old': f'input_blocks.{i}.1', 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}'}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)


    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    for i in range(num_output_blocks):
        block_id = i // (config['layers_per_block'] + 1)
        layer_in_block_id = i % (config['layers_per_block'] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
            attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

            if ['conv.weight', 'conv.bias'] in output_block_list.values():
                index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    'old': f'output_blocks.{i}.1',
                    'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
                }
                assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = '.'.join(['output_blocks', str(i), path['old']])
                new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]


    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
    down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
    up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}


    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):

        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias


    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias


    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])


    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0: i += i
            pt_layer = pt_layers[i:i+2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model

def copy_ema_weights(checkpoint, config):
    """Copies ema weights over the original weights in a state_dict
    Only applies to the unet
    """
    from ldm.util import instantiate_from_config
    model = instantiate_from_config(config.model)
    for k, v in checkpoint.items():
        if k.startswith('model.'):
            model_key = k[6:]
            ema_key = model.model_ema.m_name2s_name[model_key]
            ema_weight = checkpoint["model_ema." + ema_key]
            print(f"copying ema weight {ema_key} to {model_key}")
            checkpoint[k] = ema_weight

    return checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
    )

    parser.add_argument(
        "--use_ema", action="store_true", help="use EMA weights for conversion",
    )

    args = parser.parse_args()

    original_config = OmegaConf.load(args.original_config_file)

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"]

    if args.use_ema:
        checkpoint = copy_ema_weights(checkpoint, original_config)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    else:
        # TODO: update the convert function to use the state_dict without the model instance.
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
    scheduler = create_diffusers_schedular(original_config)
    pipe = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_model,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    pipe.save_pretrained(args.dump_path)
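Once the script has written a Diffusers-format pipeline to `--dump_path`, it can be reloaded with stock `diffusers`. A rough usage sketch (the output directory name is a placeholder, and the exact return type of the pipeline call depends on the diffusers release, so the last line is hedged accordingly):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("path/to/dump_path")  # placeholder path
pipe = pipe.to("cuda")

with torch.autocast("cuda"):
    out = pipe("A pokemon with green eyes, large wings, and a hat", guidance_scale=7.5)

# With diffusers==0.3.0 (pinned in requirements above) the result exposes `.images`;
# other releases may instead return a dict with a "sample" key.
out.images[0].save("pokemon.png")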