Additions to fine-tune TPDE dataset
This commit is contained in:
parent
3a64aae085
commit
b3deaeb181
4 changed files with 239 additions and 2 deletions
|
@ -1,6 +1,8 @@
|
||||||
# Experiments with Stable Diffusion
|
# Experiments with Stable Diffusion
|
||||||
|
|
||||||
This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
|
Fine-tune stable diffusion with images from Paris. Part of the _This Place Does Exist_ experiment.
|
||||||
|
|
||||||
|
This repository extends and adds to the [fine-tune repo by Justin Pinkney](https://github.com/justinpinkney/stable-diffusion.git) which in turn extends the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
|
||||||
|
|
||||||
Currently it adds:
|
Currently it adds:
|
||||||
|
|
||||||
|
|
138
configs/stable-diffusion/paris.yaml
Normal file
138
configs/stable-diffusion/paris.yaml
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "image"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
scale_factor: 0.18215
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 4
|
||||||
|
use_worker_init_fn: false
|
||||||
|
# num_val_workers: 0 # Avoid a weird val dataloader issue
|
||||||
|
train:
|
||||||
|
target: paris_dataloader.ParisDataset #implements https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
|
||||||
|
params:
|
||||||
|
image_folder: ../VLoD/
|
||||||
|
width: 512
|
||||||
|
height: 512
|
||||||
|
max_images: 50
|
||||||
|
image_transforms:
|
||||||
|
- target: torchvision.transforms.Resize
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
interpolation: 3
|
||||||
|
- target: torchvision.transforms.RandomCrop
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
- target: torchvision.transforms.RandomHorizontalFlip
|
||||||
|
validation:
|
||||||
|
target: ldm.data.simple.TextOnly
|
||||||
|
params:
|
||||||
|
captions:
|
||||||
|
- "The 1st arrondissement of Paris"
|
||||||
|
- "The 14th arrondissement of Paris"
|
||||||
|
- "A shop front"
|
||||||
|
- "Boulevard de Magenta"
|
||||||
|
output_size: 512
|
||||||
|
n_gpus: 2 # small hack to make sure we see all our samples
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
find_unused_parameters: False
|
||||||
|
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 2000
|
||||||
|
save_top_k: -1
|
||||||
|
monitor: null
|
||||||
|
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 2000
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: True
|
||||||
|
log_all_val: True
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: True
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
unconditional_guidance_scale: 3.0
|
||||||
|
unconditional_guidance_label: [""]
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
||||||
|
# gpus: "1," # for some reason triggers error
|
2
main.py
2
main.py
|
@ -215,7 +215,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
|
||||||
else:
|
else:
|
||||||
init_fn = None
|
init_fn = None
|
||||||
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
||||||
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
|
num_workers=self.num_workers, # shuffle=False if is_iterable_dataset else True,
|
||||||
worker_init_fn=init_fn)
|
worker_init_fn=init_fn)
|
||||||
|
|
||||||
def _val_dataloader(self, shuffle=False):
|
def _val_dataloader(self, shuffle=False):
|
||||||
|
|
97
paris_dataloader.py
Normal file
97
paris_dataloader.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
import math
|
||||||
|
from os import PathLike
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator, Optional
|
||||||
|
from ldm.data.base import Txt2ImgIterableBaseDataset
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||||
|
from torch.utils.data import get_worker_info
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from torchvision import transforms
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
Image.init() # required to initialise PIL.Image.EXTENSION
|
||||||
|
|
||||||
|
def extract_labels(f: Path) -> dict:
    """Derive caption labels for an image from its (relative) path.

    Expects a layout like ``<N>E.../<Street Name idx>/<file>``: the
    grandparent directory name begins with the arrondissement number
    (terminated by an ``E``), and the parent directory is the street name
    followed by a trailing index token, with underscores standing in for
    apostrophes.
    """
    arrondissement_dir = f.parent.parent.name
    street_dir = f.parent.name
    # Everything before the first "E" is the arrondissement number.
    number = int(arrondissement_dir.split("E")[0])
    # Drop the trailing index token, then restore apostrophes.
    name_tokens = street_dir.split(' ')[:-1]
    street = " ".join(name_tokens).replace('_', '\'')
    return {
        "arrondisement": number,
        "street": street,
    }
|
||||||
|
|
||||||
|
# adapted from stylegan dataset_tool
def is_image_ext(f: Path) -> bool:
    """Return True when the path's extension is one PIL can decode.

    Relies on ``Image.EXTENSION`` being populated (``Image.init()`` is
    called at module import time).
    """
    suffix = f.suffix.lower()
    return suffix in Image.EXTENSION
||||||
|
|
||||||
|
def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    """Recursively collect image files under ``source_dir`` and label them.

    Skips anything that is not a regular file, has a non-image extension,
    or sits directly inside a ``BU`` (backup) directory. At most
    ``max_images`` entries are returned (all of them when ``None``), each a
    dict holding the absolute path, the path relative to ``source_dir``,
    and the labels parsed from that relative path.
    """
    candidates = sorted(Path(source_dir).rglob('*'))
    image_files = []
    for f in candidates:
        if not is_image_ext(f):
            continue
        if not os.path.isfile(f):
            continue
        if f.parent.name == "BU":  # ignore backup folders
            continue
        image_files.append(f)
    image_files = image_files[:max_images]

    labeled_images = []
    for f in image_files:
        rel = f.relative_to(source_dir)
        labeled_images.append({
            "abspath": f.resolve(),
            "relpath": rel,
            "labels": extract_labels(rel),
        })
    return labeled_images
|
||||||
|
|
||||||
|
def make_ordinal(n):
    '''
    Convert an integer into its ordinal representation::

        make_ordinal(0)   => '0th'
        make_ordinal(3)   => '3rd'
        make_ordinal(122) => '122nd'
        make_ordinal(213) => '213th'
    '''
    n = int(n)
    if (n % 100) in (11, 12, 13):
        # 11th, 12th, 13th are irregular.
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"
|
||||||
|
|
||||||
|
|
||||||
|
class ParisDataset(IterableDataset):
    """Iterable dataset of captioned Paris shop-front images.

    Each yielded sample is a dict with an ``image`` tensor scaled to
    [-1, 1] in ``h w c`` layout and a ``txt`` caption built from the
    street / arrondissement labels derived from the folder structure.
    """
    # see also ldm.utils.data.simple.FolderData
    def __init__(self, image_folder: os.PathLike, width, height, max_images, image_transforms):
        # BUG FIX: the original `super(ParisDataset).__init__()` creates an
        # *unbound* super object and never runs the parent initializer; the
        # zero-argument form is the correct Python 3 idiom.
        super().__init__()
        assert os.path.exists(image_folder), "image_folder does not exist"
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height

        # Instantiate the configured transforms, then append conversion to a
        # tensor rescaled to [-1, 1] and rearranged channel-last.
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

    def __iter__(self):
        """Yield samples, splitting the image list evenly across dataloader workers."""
        worker_info = get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = 0
            iter_end = len(self.labeled_images)
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.labeled_images))

        for image in self.labeled_images[iter_start:iter_end]:
            yield {
                #with tform, the scaling is superfluous
                'image': self.tform(ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")),
                'txt': f"Shop front on the {image['labels']['street']} in the {make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
            }
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Smoke test: iterate the dataset and print every sample.
    # BUG FIX: the constructor requires `max_images` and `image_transforms`;
    # the previous 3-argument call raised a TypeError before iterating.
    d = ParisDataset("../VLoD/", 512, 512, max_images=50, image_transforms=[])
    for i in d:
        print(i)
|
Loading…
Reference in a new issue