Finetuning (#11)
* simple datasets
* add conversion script
* finish fine tune example
* update readme
* update readme
This commit is contained in:
parent 704f564366
commit f1293f9795

11 changed files with 942 additions and 15 deletions
1	.gitignore (vendored)

@@ -1,6 +1,5 @@
 logs/
 dump/
-im-examples/
 outputs/
 flagged/
 *.egg-info
28	README.md

@@ -1,13 +1,33 @@
 # Experiments with Stable Diffusion
 
+This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
+
+Currently it adds:
+
+- [Fine tuning](#fine-tuning)
+- [Image variations](#image-variations)
+- [Conversion to Huggingface Diffusers](scripts/convert_sd_to_diffusers.py)
+
+## Fine tuning
+
+Makes it easy to fine tune Stable Diffusion on your own dataset. For example generating new Pokemon from text:
+
+![](assets/pokemontage.jpg)
+
+> Girl with a pearl earring, Cute Obama creature, Donald Trump, Boris Johnson, Totoro, Hello Kitty
+
+For a step by step guide see the [Lambda Labs examples repo](https://github.com/LambdaLabsML/examples).
+
 ## Image variations
 
-[![](assets/img-vars.jpg)](https://twitter.com/Buntworthy/status/1561703483316781057)
+![](assets/im-vars-thin.jpg)
 
-Try it out in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
+[![Open Demo](https://img.shields.io/badge/%CE%BB-Open%20Demo-blueviolet)](https://47725.gradio.app/)
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JqNbI_kDq_Gth2MIYdsphgNgyGIJxBgB?usp=sharing)
+[![Open in Spaces](https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-orange)]()
+
+For more details on the Image Variation model see the [model card](https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned).
-_TODO describe in more detail_
 
 - Get access to a Linux machine with a decent NVIDIA GPU (e.g. on [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud))
 - Clone this repo
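A quick way to see what the fine-tuning data above looks like is to load the captioned Pokemon set directly with the `datasets` library. This is an illustrative sketch, not part of the commit; it assumes Hugging Face Hub access from your environment, and the output filename is just a placeholder.

# Sketch only: browse the dataset used by the fine-tuning example.
from datasets import load_dataset

ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
print(len(ds))             # number of image/caption pairs
print(ds[0]["text"])       # the BLIP caption for the first image
ds[0]["image"].save("sample_pokemon.png")   # the matching PIL image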
BIN	assets/pokemontage.jpg (new file)
Binary file not shown. (148 KiB)
133	configs/stable-diffusion/pokemon.yaml (new file)

@@ -0,0 +1,133 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset
      params:
        name: lambdalabs/pokemon-blip-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "A pokemon with green eyes, large wings, and a hat"
        - "A cute bunny rabbit"
        - "Yoda"
        - "An epic landscape photo of a mountain"
        output_size: 512
        n_gpus: 2 # small hack to sure we see all our samples


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
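For reference, the data section above is consumed by main.DataModuleFromConfig, which builds each dataset with ldm.util.instantiate_from_config. A minimal sketch of that mapping, assuming the repo is installed and the config sits at the committed path:

# Sketch only: resolve the validation block of pokemon.yaml into a TextOnly dataset.
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/pokemon.yaml")
val_ds = instantiate_from_config(config.data.params.validation)   # ldm.data.simple.TextOnly
sample = val_ds[0]
print(sample["txt"])            # one of the fixed validation captions
print(sample["image"].shape)    # torch.Size([512, 512, 3]), dummy image scaled to [-1, 1]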
101	ldm/data/simple.py (new file)

@@ -0,0 +1,101 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset


class FolderData(Dataset):
    def __init__(self, root_dir, caption_file, image_transforms, ext="jpg") -> None:
        self.root_dir = Path(root_dir)
        with open(caption_file, "rt") as f:
            captions = json.load(f)
        self.captions = captions

        self.paths = list(self.root_dir.rglob(f"*.{ext}"))
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

        # assert all(['full/' + str(x.name) in self.captions for x in self.paths])

    def __len__(self):
        return len(self.captions.keys())

    def __getitem__(self, index):
        chosen = list(self.captions.keys())[index]
        im = Image.open(self.root_dir/chosen)
        im = self.process_im(im)
        caption = self.captions[chosen]
        if caption is None:
            caption = "old book illustration"
        return {"jpg": im, "txt": caption}

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)


def hf_dataset(
    name,
    image_transforms=[],
    image_column="image",
    text_column="text",
    split='train',
    image_key='image',
    caption_key='txt',
    ):
    """Make huggingface dataset with appropriate list of transforms applied
    """
    ds = load_dataset(name, split=split)
    image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
    image_transforms.extend([transforms.ToTensor(),
                             transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
    tform = transforms.Compose(image_transforms)

    assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
    assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"

    def pre_process(examples):
        processed = {}
        processed[image_key] = [tform(im) for im in examples[image_column]]
        processed[caption_key] = examples[text_column]
        return processed

    ds.set_transform(pre_process)
    return ds


class TextOnly(Dataset):
    def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
        """Returns only captions with dummy images"""
        self.output_size = output_size
        self.image_key = image_key
        self.caption_key = caption_key
        if isinstance(captions, Path):
            self.captions = self._load_caption_file(captions)
        else:
            self.captions = captions

        if n_gpus > 1:
            # hack to make sure that all the captions appear on each gpu
            repeated = [n_gpus*[x] for x in self.captions]
            self.captions = []
            [self.captions.extend(x) for x in repeated]

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, index):
        dummy_im = torch.zeros(3, self.output_size, self.output_size)
        dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
        return {self.image_key: dummy_im, self.caption_key: self.captions[index]}

    def _load_caption_file(self, filename):
        with open(filename, 'rt') as f:
            captions = f.readlines()
        return [x.strip('\n') for x in captions]
@@ -159,7 +159,8 @@ class DDIMSampler(object):
                                       unconditional_conditioning=unconditional_conditioning,
                                       dynamic_threshold=dynamic_threshold)
             img, pred_x0 = outs
-            if callback: callback(i)
+            if callback:
+                img = callback(i, img, pred_x0)
             if img_callback: img_callback(pred_x0, i)
 
             if index % log_every_t == 0 or index == total_steps - 1:
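The hunk above changes the sampler's per-step callback so it now receives the running latent and the current x0 prediction and must return the (possibly modified) latent, instead of only being told the step index. A hedged sketch of a callback matching the new signature (the function name is a placeholder; callers still passing the old callback(i) form would need updating):

# Sketch only: a step callback for the updated DDIMSampler interface.
def latent_callback(i, img, pred_x0):
    # inspect (or edit) the latent at step i; returning it unchanged is a no-op
    print(f"step {i}: latent std {img.std().item():.3f}")
    return img

# passed as e.g. sampler.sample(..., callback=latent_callback)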
@@ -1343,9 +1343,8 @@ class LatentDiffusion(DDPM):
             log["samples_x0_quantized"] = x_samples
 
         if unconditional_guidance_scale > 1.0:
-            # uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
-            # FIXME
-            uc = torch.zeros_like(c)
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # uc = torch.zeros_like(c)
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                  ddim_steps=ddim_steps, eta=ddim_eta,

@@ -1396,6 +1395,13 @@ class LatentDiffusion(DDPM):
     def configure_optimizers(self):
         lr = self.learning_rate
         params = list(self.model.parameters())
+        # FIXME JP
+        # params = []
+        # from ldm.modules.attention import CrossAttention
+        # for n, m in self.model.named_modules():
+        #     if isinstance(m, CrossAttention) and n.endswith('attn2'):
+        #         params.extend(m.parameters())
+        # END FIXME JP
         if self.cond_stage_trainable:
             print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
             params = params + list(self.cond_stage_model.parameters())
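The first hunk switches the image-logging path from a zero tensor to a real unconditional embedding built from unconditional_guidance_label (the empty prompt [""] in pokemon.yaml). For context, that uc feeds the standard classifier-free guidance mix; the following is an illustration only, not code from this commit:

# Illustration only: the usual classifier-free guidance combination that uc is used for.
import torch

def guided_eps(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, scale: float = 3.0) -> torch.Tensor:
    # scale corresponds to unconditional_guidance_scale (3.0 in pokemon.yaml)
    return eps_uncond + scale * (eps_cond - eps_uncond)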
@@ -172,6 +172,19 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text):
         return self(text)
 
+
+class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):  # clip-vit-base-patch32
+        super().__init__()
+        self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
+        self.projection = torch.nn.Linear(768, 768)
+
+    def forward(self, text):
+        z = self.embedder(text)
+        return self.projection(z)
+
+    def encode(self, text):
+        return self(text)
 
 class FrozenCLIPImageEmbedder(AbstractEncoder):
     """
     Uses the CLIP image encoder.

@@ -192,6 +205,14 @@ class FrozenCLIPImageEmbedder(AbstractEncoder):
         self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
         self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
 
+        # I didn't call this originally, but seems like it was frozen anyway
+        self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
     def preprocess(self, x):
         # Expects inputs in the range -1, 1
         x = kornia.geometry.resize(x, (224, 224),
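The first hunk adds ProjectedFrozenCLIPEmbedder, a frozen CLIP text encoder followed by a trainable 768-to-768 linear head; the second makes the CLIP image encoder explicitly frozen. A hedged sketch of the new text embedder's output shape, assuming the repo is installed (downloads openai/clip-vit-large-patch14 on first use):

# Sketch only: exercise the new ProjectedFrozenCLIPEmbedder on CPU.
import torch
from ldm.modules.encoders.modules import ProjectedFrozenCLIPEmbedder

embedder = ProjectedFrozenCLIPEmbedder(device="cpu")
with torch.no_grad():
    z = embedder.encode(["a pokemon with green eyes"])
print(z.shape)   # (1, 77, 768): max_length tokens, 768-dim features after the trainable projection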
20	main.py

@@ -172,11 +172,15 @@ def worker_init_fn(_):
 class DataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                  wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
-                 shuffle_val_dataloader=False):
+                 shuffle_val_dataloader=False, num_val_workers=None):
         super().__init__()
         self.batch_size = batch_size
         self.dataset_configs = dict()
         self.num_workers = num_workers if num_workers is not None else batch_size * 2
+        if num_val_workers is None:
+            self.num_val_workers = self.num_workers
+        else:
+            self.num_val_workers = num_val_workers
         self.use_worker_init_fn = use_worker_init_fn
         if train is not None:
             self.dataset_configs["train"] = train

@@ -221,7 +225,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
             init_fn = None
         return DataLoader(self.datasets["validation"],
                           batch_size=self.batch_size,
-                          num_workers=self.num_workers,
+                          num_workers=self.num_val_workers,
                           worker_init_fn=init_fn,
                           shuffle=shuffle)

@@ -304,7 +308,7 @@ class SetupCallback(Callback):
 class ImageLogger(Callback):
     def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                  rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
-                 log_images_kwargs=None):
+                 log_images_kwargs=None, log_all_val=False):
         super().__init__()
         self.rescale = rescale
         self.batch_freq = batch_frequency

@@ -320,6 +324,7 @@ class ImageLogger(Callback):
         self.log_on_batch_idx = log_on_batch_idx
         self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
         self.log_first_step = log_first_step
+        self.log_all_val = log_all_val
 
     @rank_zero_only
     def _testtube(self, pl_module, images, batch_idx, split):

@@ -354,10 +359,13 @@ class ImageLogger(Callback):
 
     def log_img(self, pl_module, batch, batch_idx, split="train"):
         check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
-        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+        if self.log_all_val and split == "val":
+            should_log = True
+        else:
+            should_log = self.check_frequency(check_idx)
+        if (should_log and  # batch_idx % self.batch_freq == 0
                 hasattr(pl_module, "log_images") and
                 callable(pl_module.log_images) and
-                batch_idx > 5 and
                 self.max_images > 0):
             logger = type(pl_module.logger)

@@ -687,7 +695,7 @@ if __name__ == "__main__":
                 }
             },
         }
-        default_logger_cfg = default_logger_cfgs["wandb"]
+        default_logger_cfg = default_logger_cfgs["testtube"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
         else:
@@ -15,6 +15,8 @@ webdataset==0.2.5
 torchmetrics==0.6.0
 fire==0.4.0
 gradio==3.1.4
+diffusers==0.3.0
+datasets[vision]==2.4.0
 -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
 -e git+https://github.com/openai/CLIP.git@main#egg=clip
 -e .
636	scripts/convert_sd_to_diffusers.py (new file)

@@ -0,0 +1,636 @@
# coding=utf-8
# Modified by Justin Pinkney
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import torch

try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError("OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`.")

from transformers import BertTokenizerFast, CLIPTokenizer, CLIPTextModel
from transformers import CLIPFeatureExtractor
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, PNDMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertModel, LDMBertConfig


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return '.'.join(path.split('.')[n_shave_prefix_segments:])
    else:
        return '.'.join(path.split('.')[:n_shave_prefix_segments])


def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item.replace('in_layers.0', 'norm1')
        new_item = new_item.replace('in_layers.2', 'conv1')

        new_item = new_item.replace('out_layers.0', 'norm2')
        new_item = new_item.replace('out_layers.3', 'conv2')

        new_item = new_item.replace('emb_layers.1', 'time_emb_proj')
        new_item = new_item.replace('skip_connection', 'conv_shortcut')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace('nin_shortcut', 'conv_shortcut')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace('norm.weight', 'group_norm.weight')
        new_item = new_item.replace('norm.bias', 'group_norm.bias')

        new_item = new_item.replace('q.weight', 'query.weight')
        new_item = new_item.replace('q.bias', 'query.bias')

        new_item = new_item.replace('k.weight', 'key.weight')
        new_item = new_item.replace('k.bias', 'key.bias')

        new_item = new_item.replace('v.weight', 'value.weight')
        new_item = new_item.replace('v.bias', 'value.bias')

        new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({'old': old_item, 'new': new_item})

    return mapping


def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map['query']] = query.reshape(target_shape)
            checkpoint[path_map['key']] = key.reshape(target_shape)
            checkpoint[path_map['value']] = value.reshape(target_shape)

    for path in paths:
        new_path = path['new']

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0')
        new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0')
        new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1')

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement['old'], replacement['new'])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path['old']]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    unet_params = original_config.model.params.unet_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=64,
        in_channels=unet_params.in_channels,
        out_channels=unet_params.out_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_res_blocks,
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )

    return config


def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    latent_channles = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=512,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config


def create_diffusers_schedular(original_config):
    schedular = PNDMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
        skip_prk_steps=True,
    )
    return schedular


def create_ldm_bert_config(original_config):
    bert_params = original_config.model.parms.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict['time_embed.0.weight']
    new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict['time_embed.0.bias']
    new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict['time_embed.2.weight']
    new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict['time_embed.2.bias']

    new_checkpoint['conv_in.weight'] = unet_state_dict['input_blocks.0.0.weight']
    new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']

    new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
    new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
    new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
    new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'input_blocks' in layer})
    input_blocks = {layer_id: [key for key in unet_state_dict if f'input_blocks.{layer_id}' in key] for layer_id in range(num_input_blocks)}

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'middle_block' in layer})
    middle_blocks = {layer_id: [key for key in unet_state_dict if f'middle_block.{layer_id}' in key] for layer_id in range(num_middle_blocks)}

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({'.'.join(layer.split('.')[:2]) for layer in unet_state_dict if 'output_blocks' in layer})
    output_blocks = {layer_id: [key for key in unet_state_dict if f'output_blocks.{layer_id}' in key] for layer_id in range(num_output_blocks)}

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config['layers_per_block'] + 1)
        layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)

        resnets = [key for key in input_blocks[i] if f'input_blocks.{i}.0' in key and f'input_blocks.{i}.0.op' not in key]
        attentions = [key for key in input_blocks[i] if f'input_blocks.{i}.1' in key]

        if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
            new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.weight'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
            new_checkpoint[f'down_blocks.{block_id}.downsamplers.0.conv.bias'] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')

        paths = renew_resnet_paths(resnets)
        meta_path = {'old': f'input_blocks.{i}.0', 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}'}
        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {'old': f'input_blocks.{i}.1', 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}'}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

    for i in range(num_output_blocks):
        block_id = i // (config['layers_per_block'] + 1)
        layer_in_block_id = i % (config['layers_per_block'] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split('.')[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f'output_blocks.{i}.0' in key]
            attentions = [key for key in output_blocks[i] if f'output_blocks.{i}.1' in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {'old': f'output_blocks.{i}.0', 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}'}
            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)

            if ['conv.weight', 'conv.bias'] in output_block_list.values():
                index = list(output_block_list.values()).index(['conv.weight', 'conv.bias'])
                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.weight'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
                new_checkpoint[f'up_blocks.{block_id}.upsamplers.0.conv.bias'] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    'old': f'output_blocks.{i}.1',
                    'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}'
                }
                assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = '.'.join(['output_blocks', str(i), path['old']])
                new_path = '.'.join(['up_blocks', str(block_id), 'resnets', str(layer_in_block_id), path['new']])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'encoder.down' in layer})
    down_blocks = {layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] for layer_id in range(num_down_blocks)}

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({'.'.join(layer.split('.')[:3]) for layer in vae_state_dict if 'decoder.up' in layer})
    up_blocks = {layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] for layer_id in range(num_up_blocks)}

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f'down.{i}' in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'down.{i}.block', 'new': f'down_blocks.{i}.resnets'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [key for key in up_blocks[block_id] if f'up.{block_id}' in key and f"up.{block_id}.upsample" not in key]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'up.{block_id}.block', 'new': f'up_blocks.{i}.resnets'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {'old': f'mid.block_{i}', 'new': f'mid_block.resnets.{i - 1}'}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):

        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0: i += i
            pt_layer = pt_layers[i:i+2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model


def copy_ema_weights(checkpoint, config):
    """Copies ema weights over the original weights in a state_dict

    Only applies to the unet
    """
    from ldm.util import instantiate_from_config
    model = instantiate_from_config(config.model)
    for k, v in checkpoint.items():
        if k.startswith('model.'):
            model_key = k[6:]
            ema_key = model.model_ema.m_name2s_name[model_key]
            ema_weight = checkpoint["model_ema." + ema_key]
            print(f"copying ema weight {ema_key} to {model_key}")
            checkpoint[k] = ema_weight

    return checkpoint


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to the output model."
    )

    parser.add_argument(
        "--use_ema", action="store_true", help="use EMA weights for conversion",
    )

    args = parser.parse_args()

    original_config = OmegaConf.load(args.original_config_file)

    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")["state_dict"]

    if args.use_ema:
        checkpoint = copy_ema_weights(checkpoint, original_config)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    else:
        # TODO: update the convert function to use the state_dict without the model instance.
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

    safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-large-patch14")
    scheduler = create_diffusers_schedular(original_config)
    pipe = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_model,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    pipe.save_pretrained(args.dump_path)
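Once the script has written a diffusers-format directory to --dump_path, it can be loaded back as a regular pipeline. This is a hedged sketch, not part of the commit: the paths are placeholders, and the pinned diffusers==0.3.0 may expose the output field under a different name than newer releases, hence the fallback.

# Sketch only: load a checkpoint converted by the script above and sample from it.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("path/to/dump_path")  # the --dump_path directory
pipe = pipe.to("cuda")

with torch.autocast("cuda"):
    result = pipe("Yoda")
image = result.images[0] if hasattr(result, "images") else result["sample"][0]
image.save("yoda.png")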