stable-diffusion-finetune/paris_dataloader.py
2022-10-27 16:03:36 +02:00

97 lines
No EOL
3.6 KiB
Python

import math
from os import PathLike
import os
from pathlib import Path
from typing import Iterator, Optional
from ldm.data.base import Txt2ImgIterableBaseDataset
from PIL import Image, ImageOps
import numpy as np
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from torch.utils.data import get_worker_info
from ldm.util import instantiate_from_config
from torchvision import transforms
from einops import rearrange
Image.init() # required to initialise PIL.Image.EXTENSION
def extract_labels(f: Path) -> dict:
    """Derive the labels of an image from the folders that contain it.

    The grandparent folder name supplies the arrondissement: the text in
    front of its first "E" is parsed as an integer. The parent folder name
    supplies the street: its last space-separated token is dropped and every
    underscore is turned into an apostrophe.

    Returns a dict with the keys "arrondisement" (int) and "street" (str).
    Raises ValueError when the text before the first "E" is not an integer.
    """
    district_folder = f.parent.parent.name
    street_folder = f.parent.name
    # Text preceding the first "E" holds the arrondissement number.
    number_text, _, _ = district_folder.partition("E")
    # Keep everything before the last space; empty when there is no space.
    street_text, _, _ = street_folder.rpartition(' ')
    return {
        "arrondisement": int(number_text),
        "street": street_text.replace('_', "'"),
    }
# adapted from stylegan dataset_tool
def is_image_ext(f: Path) -> bool:
    """Tell whether the lower-cased suffix of *f* is listed in ``Image.EXTENSION``."""
    suffix = f.suffix.lower()
    known_extensions = Image.EXTENSION
    return suffix in known_extensions
def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    """Collect the labeled image files found below *source_dir*.

    The directory is walked recursively in sorted order. An entry is kept
    when it has an image extension, is a regular file and does not sit
    directly inside a folder named "BU". At most *max_images* entries are
    kept (all of them when it is None).

    Returns a list of dicts with the keys "abspath" (resolved path),
    "relpath" (path relative to *source_dir*) and "labels" (the result of
    ``extract_labels`` applied to the relative path).
    """
    candidates = []
    for path in sorted(Path(source_dir).rglob('*')):
        if not is_image_ext(path):
            continue
        if not os.path.isfile(path):
            continue
        if path.parent.name == "BU":
            continue
        candidates.append(path)

    entries = []
    for path in candidates[:max_images]:
        relative = path.relative_to(source_dir)
        entries.append({
            "abspath": path.resolve(),
            "relpath": relative,
            "labels": extract_labels(relative),
        })
    return entries
def make_ordinal(n):
    '''
    Render an integer (or anything ``int()`` accepts) as an English ordinal::
    make_ordinal(0) => '0th'
    make_ordinal(3) => '3rd'
    make_ordinal(122) => '122nd'
    make_ordinal(213) => '213th'
    '''
    number = int(n)
    # 11, 12 and 13 (also 111, 212, ...) take 'th' regardless of last digit.
    if (number % 100) in (11, 12, 13):
        return f"{number}th"
    ending = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
    return f"{number}{ending}"
class ParisDataset(IterableDataset):
    """Iterable dataset yielding captioned shop-front images.

    Every item is a dict with the keys 'image' (the transformed picture)
    and 'txt' (a caption built from the street and arrondissement labels
    of the file).
    """
    # see also ldm.utils.data.simple.FolderData

    def __init__(self, image_folder: os.PathLike, width, height, max_images: Optional[int] = None, image_transforms=()):
        """Index the images below *image_folder* and build the transform.

        Args:
            image_folder: Existing directory scanned by ``open_image_folder``.
            width: Target width handed to ``ImageOps.fit``.
            height: Target height handed to ``ImageOps.fit``.
            max_images: Upper bound on the number of images used; None
                (the default) keeps all of them.
            image_transforms: Iterable of configs, each turned into a
                transform with ``instantiate_from_config``. Defaults to no
                extra transforms.

        Raises:
            AssertionError: If *image_folder* does not exist.
        """
        # Zero-argument super() actually runs the base-class initialiser;
        # the one-argument form super(ParisDataset) only builds an unbound
        # super object and initialises that instead.
        super().__init__()
        assert os.path.exists(image_folder), "image_folder does not exist"
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        # Always finish with tensor conversion, the x * 2 - 1 rescale and a
        # channel-last ('h w c') layout.
        image_transforms.extend([transforms.ToTensor(),
                                 transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        image_transforms = transforms.Compose(image_transforms)
        self.tform = image_transforms

    def __iter__(self):
        """Yield one ``{'image', 'txt'}`` dict per image assigned to this worker.

        Without worker info every image is yielded. Inside a worker process
        the image list is cut into contiguous chunks of
        ``ceil(len / num_workers)`` items and only the chunk belonging to
        the current worker id is yielded.
        """
        worker_info = get_worker_info()
        total = len(self.labeled_images)
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = 0
            iter_end = total
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil(total / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, total)
        for image in self.labeled_images[iter_start:iter_end]:
            labels = image['labels']
            # Close the underlying file as soon as the fitted RGB copy
            # exists, so long iterations do not accumulate open handles.
            with Image.open(image['abspath']) as source:
                # with tform, the scaling is superfluous
                fitted = ImageOps.fit(source, (self.width, self.height)).convert("RGB")
            yield {
                'image': self.tform(fitted),
                'txt': f"Shop front on the {labels['street']} in the {make_ordinal(labels['arrondisement'])} arrondissement of Paris",
            }
if __name__ == "__main__":
    # Manual smoke test: print every item of the dataset. The constructor
    # also takes max_images and image_transforms; pass them explicitly
    # (no limit, no extra transforms) so the call does not fail with a
    # TypeError for missing arguments.
    d = ParisDataset("../VLoD/", 512, 512, max_images=None, image_transforms=[])
    for i in d:
        print(i)