Additions to fine-tune TPDE dataset
commit b3deaeb181 (parent 3a64aae085)
4 changed files with 239 additions and 2 deletions

README.md

@@ -1,6 +1,8 @@
 # Experiments with Stable Diffusion
 
-This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
+Fine-tune Stable Diffusion with images from Paris. Part of the _This Place Does Exist_ experiment.
+
+This repository extends and adds to the [fine-tune repo by Justin Pinkney](https://github.com/justinpinkney/stable-diffusion.git), which in turn extends the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
 
 Currently it adds:
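The additions are the three files below. As a sketch of how they are meant to be used together, fine-tuning would be launched through the training entry point inherited from the upstream repos; the `-t`, `--base` and `--gpus` flags come from that entry point, while the GPU list is illustrative and the way the pretrained checkpoint is supplied is not pinned down by this commit:

```python
# Minimal launch sketch (assumptions noted above): run the upstream main.py in
# training mode with the Paris config added by this commit.
import subprocess

subprocess.run(
    [
        "python", "main.py",
        "-t",                                             # train
        "--base", "configs/stable-diffusion/paris.yaml",  # config from this commit
        "--gpus", "0,1",                                  # illustrative; matches n_gpus: 2 in the config
    ],
    check=True,
)
```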
configs/stable-diffusion/paris.yaml  (new file, 138 lines)
@@ -0,0 +1,138 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "image"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1 ] # NOTE: for resuming; use 10000 if starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    use_worker_init_fn: false
    # num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: paris_dataloader.ParisDataset # implements https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
      params:
        image_folder: ../VLoD/
        width: 512
        height: 512
        max_images: 50
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "The 1st arrondissement of Paris"
        - "The 14th arrondissement of Paris"
        - "A shop front"
        - "Boulevard de Magenta"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 2000
      save_top_k: -1
      monitor: null

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_all_val: True
        log_images_kwargs:
          use_ema_scope: True
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    # gpus: "1,"  # for some reason this triggers an error

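Since the train block points its target at paris_dataloader.ParisDataset, the dataset can be sanity-checked outside the Lightning loop with the same instantiate_from_config machinery main.py uses. A minimal sketch, assuming the ../VLoD/ image folder referenced by the config is present locally:

```python
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("configs/stable-diffusion/paris.yaml")

# DataModuleFromConfig in main.py does essentially this for each split:
# build the dataset from its target/params pair.
train_cfg = config.data.params.train
dataset = instantiate_from_config(train_cfg)  # -> paris_dataloader.ParisDataset

sample = next(iter(dataset))
print(sample["txt"])          # e.g. "Shop front on the ... arrondissement of Paris"
print(sample["image"].shape)  # torch.Size([512, 512, 3]), values scaled to [-1, 1]
```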
main.py  (2 lines changed)

@@ -215,7 +215,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
         else:
             init_fn = None
         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
+                          num_workers=self.num_workers, # shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)
 
     def _val_dataloader(self, shuffle=False):
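For context on why the shuffle argument is commented out: ParisDataset below subclasses torch.utils.data.IterableDataset but, presumably, not the repo's Txt2ImgIterableBaseDataset, so is_iterable_dataset ends up False and shuffle=True would be handed to an iterable-style dataset, which PyTorch rejects. A minimal sketch of that failure mode, using an illustrative toy dataset rather than anything from the repo:

```python
from torch.utils.data import DataLoader, IterableDataset

class TinyIterable(IterableDataset):
    """Illustrative iterable-style dataset yielding a few integers."""
    def __iter__(self):
        return iter(range(8))

# Asking for shuffling on an iterable-style dataset raises at construction time;
# this is the error the edited line above avoids.
try:
    DataLoader(TinyIterable(), batch_size=2, shuffle=True)
except ValueError as err:
    print("rejected:", err)

# Leaving shuffle unspecified (as main.py now does) works.
loader = DataLoader(TinyIterable(), batch_size=2)
print(next(iter(loader)))  # tensor([0, 1])
```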
paris_dataloader.py  (new file, 97 lines)
@@ -0,0 +1,97 @@
import math
from os import PathLike
import os
from pathlib import Path
from typing import Iterator, Optional

from ldm.data.base import Txt2ImgIterableBaseDataset
from PIL import Image, ImageOps
import numpy as np
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import get_worker_info
from ldm.util import instantiate_from_config
from torchvision import transforms
from einops import rearrange

Image.init()  # required to initialise PIL.Image.EXTENSION


def extract_labels(f: Path) -> dict:
    # get the labels for the image path
    arr = int(f.parent.parent.name.split("E")[0])
    street = " ".join(f.parent.name.split(' ')[:-1]).replace('_', '\'')
    return {
        "arrondisement": arr,
        "street": street,
    }


# adapted from stylegan dataset_tool
def is_image_ext(f: Path) -> bool:
    return f.suffix.lower() in Image.EXTENSION


def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    image_files = [f for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f) and f.parent.name != "BU"]
    image_files = image_files[:max_images]

    labeled_images = [{
        "abspath": f.resolve(),
        "relpath": f.relative_to(source_dir),
        "labels": extract_labels(f.relative_to(source_dir))
    } for f in image_files]

    return labeled_images


def make_ordinal(n):
    '''
    Convert an integer into its ordinal representation::

        make_ordinal(0)   => '0th'
        make_ordinal(3)   => '3rd'
        make_ordinal(122) => '122nd'
        make_ordinal(213) => '213th'
    '''
    n = int(n)
    if 11 <= (n % 100) <= 13:
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
    return str(n) + suffix


class ParisDataset(IterableDataset):
    # see also ldm.data.simple.FolderData
    def __init__(self, image_folder: os.PathLike, width, height, max_images, image_transforms):
        super().__init__()
        assert os.path.exists(image_folder), "image_folder does not exist"
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height

        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = 0
            iter_end = len(self.labeled_images)
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.labeled_images))

        for image in self.labeled_images[iter_start:iter_end]:
            yield {
                # with tform, the scaling is superfluous
                'image': self.tform(ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")),
                'txt': f"Shop front on the {image['labels']['street']} in the {make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
            }


if __name__ == "__main__":
    d = ParisDataset("../VLoD/", 512, 512, max_images=10, image_transforms=[])
    for i in d:
        print(i)
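The per-worker slicing in __iter__ is what keeps multi-worker loading free of duplicate samples, and the quickest way to exercise it is to wrap the dataset in a DataLoader directly. A minimal sketch, assuming the ../VLoD/ folder from the config exists and holds at least 8 images; the max_images, worker count and batch size here are illustrative:

```python
from torch.utils.data import DataLoader
from paris_dataloader import ParisDataset

# Transform configs are plain target/params dicts, as in paris.yaml.
resize = {"target": "torchvision.transforms.Resize",
          "params": {"size": 512, "interpolation": 3}}

ds = ParisDataset("../VLoD/", width=512, height=512,
                  max_images=8, image_transforms=[resize])

# Each of the 2 workers iterates its own ceil(8 / 2) = 4-image slice, so the
# 8 samples come back exactly once across both workers. (Fork-based worker
# start is assumed, since the Lambda transform in the dataset is not picklable.)
loader = DataLoader(ds, batch_size=4, num_workers=2)
for batch in loader:
    print(batch["txt"])          # list of 4 caption strings
    print(batch["image"].shape)  # torch.Size([4, 512, 512, 3]) after the 'h w c' rearrange
```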