Additions to fine-tune on the TPDE dataset

This commit is contained in:
Ruben van de Ven 2022-10-27 16:03:36 +02:00
parent 3a64aae085
commit b3deaeb181
4 changed files with 239 additions and 2 deletions

README.md

@@ -1,6 +1,8 @@
 # Experiments with Stable Diffusion
-This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
+Fine-tune Stable Diffusion with images from Paris. Part of the _This Place Does Exist_ experiment.
+This repository extends and adds to the [fine-tune repo by Justin Pinkney](https://github.com/justinpinkney/stable-diffusion.git), which in turn extends the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
 Currently it adds:

New fine-tune training config (YAML file)

@@ -0,0 +1,138 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    use_worker_init_fn: false
    # num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: paris_dataloader.ParisDataset # implements https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
      params:
        image_folder: ../VLoD/
        width: 512
        height: 512
        max_images: 50
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "The 1st arrondissement of Paris"
        - "The 14th arrondissement of Paris"
        - "A shop front"
        - "Boulevard de Magenta"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples

lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    # gpus: "1," # for some reason triggers an error
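For reference, each `target`/`params` pair in this config is resolved at runtime by `instantiate_from_config` from `ldm.util`: the dotted `target` path is imported and the resulting class or function is called with `params` as keyword arguments. A simplified sketch of that mechanism (the repo's helper adds error handling and a module-reload option):

import importlib

def get_obj_from_str(string):
    # "paris_dataloader.ParisDataset" -> the ParisDataset class
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # import config["target"] and call it with config["params"] as kwargs
    return get_obj_from_str(config["target"])(**config.get("params", dict()))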

main.py

@@ -215,7 +215,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
         else:
             init_fn = None
         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
+                          num_workers=self.num_workers, # shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)

     def _val_dataloader(self, shuffle=False):
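Background for this change, as far as it can be inferred: `DataLoader` raises a `ValueError` when `shuffle=True` is combined with an `IterableDataset`, and the new `ParisDataset` is iterable; commenting the argument out falls back to the default of no shuffling. A hypothetical helper (not part of this commit) that keeps shuffling for map-style datasets while staying valid for iterable ones:

from torch.utils.data import DataLoader, IterableDataset

def make_train_loader(dataset, batch_size, num_workers, init_fn=None):
    kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                  worker_init_fn=init_fn)
    if not isinstance(dataset, IterableDataset):
        kwargs["shuffle"] = True  # shuffle only map-style datasets
    return DataLoader(dataset, **kwargs)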

paris_dataloader.py (new file, 97 lines)

@@ -0,0 +1,97 @@
import math
import os
from pathlib import Path
from typing import Optional

from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms
from einops import rearrange

from ldm.util import instantiate_from_config

Image.init()  # required to initialise PIL.Image.EXTENSION


def extract_labels(f: Path) -> dict:
    # get the labels for the image path: the grandparent folder encodes the
    # arrondissement (e.g. "14E..."), the parent folder the street name plus a
    # trailing number, with underscores standing in for apostrophes
    arr = int(f.parent.parent.name.split("E")[0])
    street = " ".join(f.parent.name.split(' ')[:-1]).replace('_', '\'')
    return {
        "arrondisement": arr,
        "street": street,
    }


# adapted from stylegan dataset_tool
def is_image_ext(f: Path) -> bool:
    return f.suffix.lower() in Image.EXTENSION


def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    image_files = [f for f in sorted(Path(source_dir).rglob('*'))
                   if is_image_ext(f) and os.path.isfile(f) and f.parent.name != "BU"]
    image_files = image_files[:max_images]
    labeled_images = [{
        "abspath": f.resolve(),
        "relpath": f.relative_to(source_dir),
        "labels": extract_labels(f.relative_to(source_dir))
    } for f in image_files]
    return labeled_images


def make_ordinal(n):
    '''
    Convert an integer into its ordinal representation::

        make_ordinal(0)   => '0th'
        make_ordinal(3)   => '3rd'
        make_ordinal(122) => '122nd'
        make_ordinal(213) => '213th'
    '''
    n = int(n)
    if 11 <= (n % 100) <= 13:
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
    return str(n) + suffix


class ParisDataset(IterableDataset):
    # see also ldm.data.simple.FolderData
    def __init__(self, image_folder: os.PathLike, width, height, max_images, image_transforms):
        super().__init__()
        assert os.path.exists(image_folder), "image_folder does not exist"
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        self.tform = transforms.Compose(image_transforms)

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = 0
            iter_end = len(self.labeled_images)
        else:  # in a worker process: split the workload
            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.labeled_images))
        for image in self.labeled_images[iter_start:iter_end]:
            yield {
                # with tform doing Resize/RandomCrop, the ImageOps.fit scaling is superfluous
                'image': self.tform(ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")),
                'txt': f"Shop front on the {image['labels']['street']} in the {make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
            }


if __name__ == "__main__":
    d = ParisDataset("../VLoD/", 512, 512, max_images=None, image_transforms=[])
    for i in d:
        print(i)
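Because `__iter__` splits `self.labeled_images` across workers, the dataset can be fed straight into a multi-worker `DataLoader` without yielding duplicate samples. A hypothetical smoke test (the folder path and empty `image_transforms` list mirror the `__main__` block above):

from torch.utils.data import DataLoader
from paris_dataloader import ParisDataset

ds = ParisDataset("../VLoD/", 512, 512, max_images=8, image_transforms=[])
loader = DataLoader(ds, batch_size=4, num_workers=2)  # each worker iterates its own slice
for batch in loader:
    # images arrive as (batch, h, w, c) tensors in [-1, 1]; captions as strings
    print(batch["image"].shape, batch["txt"][0])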