import math
import os
from pathlib import Path
from typing import Optional

from PIL import Image, ImageOps
from torch.utils.data import IterableDataset, get_worker_info
from torchvision import transforms
from einops import rearrange

from ldm.util import instantiate_from_config

Image.init()  # required to initialise PIL.Image.EXTENSION


def extract_labels(f: Path) -> dict:
    # get the labels for the image path: the grandparent directory name encodes
    # the arrondissement number, the parent directory name the street
    arr = int(f.parent.parent.name.split("E")[0])
    street = " ".join(f.parent.name.split(' ')[:-1]).replace('_', "'")
    return {
        "arrondisement": arr,
        "street": street,
    }


# adapted from the StyleGAN dataset_tool
def is_image_ext(f: Path) -> bool:
    return f.suffix.lower() in Image.EXTENSION


def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    image_files = [
        f for f in sorted(Path(source_dir).rglob('*'))
        if is_image_ext(f) and os.path.isfile(f) and f.parent.name != "BU"
    ]
    image_files = image_files[:max_images]
    labeled_images = [{
        "abspath": f.resolve(),
        "relpath": f.relative_to(source_dir),
        "labels": extract_labels(f.relative_to(source_dir)),
    } for f in image_files]
    return labeled_images


def make_ordinal(n):
    '''
    Convert an integer into its ordinal representation::

        make_ordinal(0)   => '0th'
        make_ordinal(3)   => '3rd'
        make_ordinal(122) => '122nd'
        make_ordinal(213) => '213th'
    '''
    n = int(n)
    if 11 <= (n % 100) <= 13:
        suffix = 'th'
    else:
        suffix = ['th', 'st', 'nd', 'rd', 'th'][min(n % 10, 4)]
    return str(n) + suffix


class ParisDataset(IterableDataset):
    # see also ldm.data.simple.FolderData
    def __init__(self, image_folder: os.PathLike, width, height,
                 max_images: Optional[int] = None, image_transforms=()):
        super().__init__()
        assert os.path.exists(image_folder), "image_folder does not exist"
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height
        image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
        image_transforms.extend([
            transforms.ToTensor(),
            # map [0, 1] to [-1, 1] and move channels last
            transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')),
        ])
        self.tform = transforms.Compose(image_transforms)

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:
            # single-process data loading: iterate over the full dataset
            iter_start = 0
            iter_end = len(self.labeled_images)
        else:
            # in a worker process: split the workload evenly across workers
            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.labeled_images))
        for image in self.labeled_images[iter_start:iter_end]:
            yield {
                # ImageOps.fit already resizes to (width, height), so no further
                # scaling is needed before applying tform
                'image': self.tform(
                    ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")
                ),
                'txt': f"Shop front on the {image['labels']['street']} in the "
                       f"{make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
            }


if __name__ == "__main__":
    d = ParisDataset("../VLoD/", 512, 512, max_images=None, image_transforms=[])
    for i in d:
        print(i)
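

# A minimal sketch of the per-worker sharding arithmetic used in __iter__ above,
# run with hypothetical numbers (10 items, 3 workers) rather than a real
# DataLoader, to show that the (start, end) slices cover every item exactly once.
if __name__ == "__main__":
    n_items, n_workers = 10, 3
    per_worker = int(math.ceil(n_items / float(n_workers)))
    shards = [(w * per_worker, min((w + 1) * per_worker, n_items))
              for w in range(n_workers)]
    print(shards)  # [(0, 4), (4, 8), (8, 10)]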