From b3deaeb181a9c15a377b92df92b2c73bc58ca69c Mon Sep 17 00:00:00 2001
From: Ruben van de Ven
Date: Thu, 27 Oct 2022 16:03:36 +0200
Subject: [PATCH] Additions to fine-tune TPDE dataset

---
 README.md                           |   4 +-
 configs/stable-diffusion/paris.yaml | 138 ++++++++++++++++++++++++++++
 main.py                             |   2 +-
 paris_dataloader.py                 |  97 +++++++++++++++++++
 4 files changed, 239 insertions(+), 2 deletions(-)
 create mode 100644 configs/stable-diffusion/paris.yaml
 create mode 100644 paris_dataloader.py

diff --git a/README.md b/README.md
index 4969a48..543ecb8 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
 # Experiments with Stable Diffusion
 
-This repository extends and adds to the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
+Fine-tune Stable Diffusion with images from Paris. Part of the _This Place Does Exist_ experiment.
+
+This repository extends and adds to the [fine-tune repo by Justin Pinkney](https://github.com/justinpinkney/stable-diffusion.git), which in turn extends the [original training repo](https://github.com/pesser/stable-diffusion) for Stable Diffusion.
 
 Currently it adds:
 
diff --git a/configs/stable-diffusion/paris.yaml b/configs/stable-diffusion/paris.yaml
new file mode 100644
index 0000000..0ee72a0
--- /dev/null
+++ b/configs/stable-diffusion/paris.yaml
@@ -0,0 +1,138 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "image"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    scale_factor: 0.18215
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 1 ] # NOTE: for resuming; use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+
+
+data:
+  target: main.DataModuleFromConfig
+  params:
+    batch_size: 4
+    num_workers: 4
+    use_worker_init_fn: false
+    # num_val_workers: 0 # Avoid a weird val dataloader issue
+    train:
+      target: paris_dataloader.ParisDataset # implements https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
+      params:
+        image_folder: ../VLoD/
+        width: 512
+        height: 512
+        max_images: 50
+        image_transforms:
+        - target: torchvision.transforms.Resize
+          params:
+            size: 512
+            interpolation: 3
+        - target: torchvision.transforms.RandomCrop
+          params:
+            size: 512
+        - target: torchvision.transforms.RandomHorizontalFlip
+    validation:
+      target: ldm.data.simple.TextOnly
+      params:
+        captions:
+        - "The 1st arrondissement of Paris"
+        - "The 14th arrondissement of Paris"
+        - "A shop front"
+        - "Boulevard de Magenta"
+        output_size: 512
+        n_gpus: 2 # small hack to make sure we see all our samples
+
+
+lightning:
+  find_unused_parameters: False
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 2000
+      save_top_k: -1
+      monitor: null
+
+  callbacks:
+    image_logger:
+      target: main.ImageLogger
+      params:
+        batch_frequency: 2000
+        max_images: 4
+        increase_log_steps: False
+        log_first_step: True
+        log_all_val: True
+        log_images_kwargs:
+          use_ema_scope: True
+          inpaint: False
+          plot_progressive_rows: False
+          plot_diffusion_rows: False
+          N: 4
+          unconditional_guidance_scale: 3.0
+          unconditional_guidance_label: [""]
+
+  trainer:
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    # gpus: "1," # for some reason triggers an error
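The `data` block above is consumed by `main.DataModuleFromConfig`, which builds each split through `ldm.util.instantiate_from_config`. Roughly, as a simplified sketch of the upstream helper (not code from this patch):

```python
import importlib

def instantiate_from_config(config: dict):
    # "target" is a dotted import path such as paris_dataloader.ParisDataset;
    # "params" are passed to the class constructor as keyword arguments.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))
```

This is why `train.params` must match `ParisDataset.__init__` exactly: `image_folder`, `width`, `height`, `max_images`, and `image_transforms` (itself a list of nested `target`/`params` blocks, instantiated the same way).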
diff --git a/main.py b/main.py
index 8a70313..40af9a0 100644
--- a/main.py
+++ b/main.py
@@ -215,7 +215,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
         else:
             init_fn = None
         return DataLoader(self.datasets["train"], batch_size=self.batch_size,
-                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
+                          num_workers=self.num_workers, # shuffle=False if is_iterable_dataset else True,
                           worker_init_fn=init_fn)
 
     def _val_dataloader(self, shuffle=False):
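The shuffle argument is dropped here presumably because `DataModuleFromConfig` detects iterable datasets with `isinstance(..., Txt2ImgIterableBaseDataset)`, while `ParisDataset` (below) subclasses `torch.utils.data.IterableDataset` directly; `shuffle=True` then reaches the `DataLoader`, which raises a `ValueError` for iterable datasets. A more general guard, sketched here rather than taken from the patch, would key off the base class itself:

```python
from torch.utils.data import DataLoader, IterableDataset

def make_train_loader(dataset, batch_size, num_workers, worker_init_fn=None):
    # Only map-style datasets may be shuffled by the DataLoader;
    # iterable datasets control their own ordering in __iter__.
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                      shuffle=not isinstance(dataset, IterableDataset),
                      worker_init_fn=worker_init_fn)
```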
diff --git a/paris_dataloader.py b/paris_dataloader.py
new file mode 100644
index 0000000..617609d
--- /dev/null
+++ b/paris_dataloader.py
@@ -0,0 +1,97 @@
+import math
+from os import PathLike
+import os
+from pathlib import Path
+from typing import Iterator, Optional
+from ldm.data.base import Txt2ImgIterableBaseDataset
+from PIL import Image, ImageOps
+import numpy as np
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+from torch.utils.data import get_worker_info
+from ldm.util import instantiate_from_config
+from torchvision import transforms
+from einops import rearrange
+
+Image.init() # required to initialise PIL.Image.EXTENSION
+
+def extract_labels(f: Path) -> dict:
+    # get the labels for the image path
+    arr = int(f.parent.parent.name.split("E")[0])
+    street = " ".join(f.parent.name.split(' ')[:-1]).replace('_', '\'')
+    return {
+        "arrondisement": arr,
+        "street": street,
+    }
+
+# adapted from stylegan dataset_tool
+def is_image_ext(f: Path) -> bool:
+    return f.suffix.lower() in Image.EXTENSION
+
+def open_image_folder(source_dir, *, max_images: Optional[int] = None):
+    image_files = [f for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f) and f.parent.name != "BU"]
+    image_files = image_files[:max_images]
+
+    labeled_images = [{
+        "abspath": f.resolve(),
+        "relpath": f.relative_to(source_dir),
+        "labels": extract_labels(f.relative_to(source_dir))
+    } for f in image_files]
+
+    return labeled_images
+
+def make_ordinal(n):
+    '''
+    Convert an integer into its ordinal representation::
+
+        make_ordinal(0)   => '0th'
+        make_ordinal(3)   => '3rd'
+        make_ordinal(122) => '122nd'
+        make_ordinal(213) => '213th'
+    '''
+    n = int(n)
+    if 11 <= (n % 100) <= 13:
+        suffix = 'th'
+    else:
+        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
+    return str(n) + suffix
+
+
+class ParisDataset(IterableDataset):
+    # see also ldm.data.simple.FolderData
+    def __init__(self, image_folder: os.PathLike, width, height, max_images, image_transforms):
+        super().__init__()
+        assert os.path.exists(image_folder), "image_folder does not exist"
+        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
+        self.width = width
+        self.height = height
+
+        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
+        image_transforms.extend([transforms.ToTensor(),
+                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
+        image_transforms = transforms.Compose(image_transforms)
+        self.tform = image_transforms
+
+    def __iter__(self):
+        worker_info = get_worker_info()
+        if worker_info is None: # single-process data loading, return the full iterator
+            iter_start = 0
+            iter_end = len(self.labeled_images)
+        else: # in a worker process
+            # split workload
+            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
+            worker_id = worker_info.id
+            iter_start = worker_id * per_worker
+            iter_end = min(iter_start + per_worker, len(self.labeled_images))
+
+        for image in self.labeled_images[iter_start:iter_end]:
+            yield {
+                # with tform, the scaling is superfluous
+                'image': self.tform(ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")),
+                'txt': f"Shop front on the {image['labels']['street']} in the {make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
+            }
+
+
+if __name__ == "__main__":
+    d = ParisDataset("../VLoD/", 512, 512, max_images=None, image_transforms=[])
+    for i in d:
+        print(i)
\ No newline at end of file
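A quick way to sanity-check the loader end to end (a hedged smoke test, assuming the `../VLoD/` layout the patch expects; `max_images=16` so that each of the four workers gets a full batch):

```python
from torch.utils.data import DataLoader
from paris_dataloader import ParisDataset

ds = ParisDataset("../VLoD/", width=512, height=512,
                  max_images=16, image_transforms=[])
dl = DataLoader(ds, batch_size=4, num_workers=4)  # mirrors paris.yaml

for batch in dl:
    print(batch["txt"][0])       # e.g. "Shop front on the ... arrondissement of Paris"
    print(batch["image"].shape)  # torch.Size([4, 512, 512, 3]), values in [-1, 1]
```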