97 lines
No EOL
3.6 KiB
Python
97 lines
No EOL
3.6 KiB
Python
import math
|
|
from os import PathLike
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Iterator, Optional
|
|
from ldm.data.base import Txt2ImgIterableBaseDataset
|
|
from PIL import Image, ImageOps
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
|
from torch.utils.data import get_worker_info
|
|
from ldm.util import instantiate_from_config
|
|
from torchvision import transforms
|
|
from einops import rearrange
|
|
|
|
Image.init()  # populate Image.EXTENSION (suffix -> format map) so is_image_ext() can test suffixes
|
|
|
|
def extract_labels(f: Path) -> dict:
    """Derive arrondissement and street labels from an image's relative path.

    The directory layout encodes both labels: the grandparent directory name
    starts with the arrondissement number followed by 'E' (e.g. ``3E...``),
    and the parent directory is the street name (underscores standing in for
    apostrophes) followed by one trailing token that is dropped.
    """
    # Digits before the first 'E' in the grandparent directory name.
    arrondisement = int(f.parent.parent.name.split("E")[0])
    # Drop the last space-separated token, then restore apostrophes.
    street_words = f.parent.name.split(' ')[:-1]
    street = " ".join(street_words).replace('_', "'")
    return {"arrondisement": arrondisement, "street": street}
|
|
|
|
# adapted from stylegan dataset_tool
|
|
def is_image_ext(f: Path) -> bool:
    """Return True when the path's suffix is a PIL-registered image extension."""
    ext = f.suffix.lower()
    return ext in Image.EXTENSION
|
|
|
|
def open_image_folder(source_dir, *, max_images: Optional[int] = None):
    """Recursively collect labelled image records under *source_dir*.

    Keeps regular files with a known image extension, skipping anything
    inside a ``BU`` (backup) directory, and truncates the sorted list to
    the first *max_images* entries when given.

    Returns a list of dicts with keys ``abspath``, ``relpath`` and
    ``labels`` (see :func:`extract_labels`).
    """
    candidates = sorted(Path(source_dir).rglob('*'))
    image_files = [
        p for p in candidates
        if is_image_ext(p) and os.path.isfile(p) and p.parent.name != "BU"
    ]

    labeled_images = []
    for p in image_files[:max_images]:
        rel = p.relative_to(source_dir)
        labeled_images.append({
            "abspath": p.resolve(),
            "relpath": rel,
            "labels": extract_labels(rel),
        })

    return labeled_images
|
|
|
|
def make_ordinal(n):
    '''
    Convert an integer into its ordinal representation::

        make_ordinal(0)   => '0th'
        make_ordinal(3)   => '3rd'
        make_ordinal(122) => '122nd'
        make_ordinal(213) => '213th'
    '''
    n = int(n)
    # 11th, 12th, 13th are irregular: the teens always take 'th'.
    if n % 100 in (11, 12, 13):
        return f"{n}th"
    # Otherwise the suffix follows the last digit; everything >= 4 is 'th'.
    last_digit_suffix = {1: 'st', 2: 'nd', 3: 'rd'}
    return f"{n}{last_digit_suffix.get(n % 10, 'th')}"
|
|
|
|
|
|
class ParisDataset(IterableDataset):
    """Iterable dataset of Paris shop-front photos paired with text captions.

    Each yielded item is a dict with:
      - ``image``: transformed image tensor, rescaled from [0, 1] to [-1, 1]
        and rearranged to (h, w, c);
      - ``txt``: a caption built from the street and arrondissement labels.

    See also ldm.utils.data.simple.FolderData.
    """

    def __init__(self, image_folder: os.PathLike, width, height,
                 max_images=None, image_transforms=()):
        """Index *image_folder* and build the transform pipeline.

        :param image_folder: root directory scanned by open_image_folder().
        :param width: output image width passed to ImageOps.fit.
        :param height: output image height passed to ImageOps.fit.
        :param max_images: optional cap on the number of images (None = all).
        :param image_transforms: configs instantiated via
            instantiate_from_config and applied before tensor conversion.
        :raises FileNotFoundError: if *image_folder* does not exist.
        """
        # BUG FIX: the original called super(ParisDataset).__init__(), which
        # builds an *unbound* super proxy and never runs
        # IterableDataset.__init__ at all.
        super().__init__()
        if not os.path.exists(image_folder):
            # Raise instead of assert: asserts are stripped under `python -O`.
            raise FileNotFoundError(f"image_folder does not exist: {image_folder}")
        self.labeled_images = open_image_folder(image_folder, max_images=max_images)
        self.width = width
        self.height = height

        # Project-configured transforms first, then tensor conversion and a
        # rescale from [0, 1] to [-1, 1] with channels moved last.
        tform_list = [instantiate_from_config(tt) for tt in image_transforms]
        tform_list.extend([transforms.ToTensor(),
                           transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
        self.tform = transforms.Compose(tform_list)

    def __iter__(self):
        """Yield {image, txt} dicts, sharding the file list across workers."""
        worker_info = get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = 0
            iter_end = len(self.labeled_images)
        else:  # in a worker process
            # Split the workload into contiguous, non-overlapping slices so
            # each sample is produced by exactly one worker.
            per_worker = int(math.ceil(len(self.labeled_images) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, len(self.labeled_images))

        for image in self.labeled_images[iter_start:iter_end]:
            # ImageOps.fit crops/resizes to exactly (width, height); the
            # RGB conversion drops alpha/palette modes before the transforms
            # (with tform, any extra scaling would be superfluous).
            yield {
                'image': self.tform(ImageOps.fit(Image.open(image['abspath']), (self.width, self.height)).convert("RGB")),
                'txt': f"Shop front on the {image['labels']['street']} in the {make_ordinal(image['labels']['arrondisement'])} arrondissement of Paris",
            }
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: index the folder and dump every record.
    # BUG FIX: the original call omitted the required max_images and
    # image_transforms arguments, so construction raised a TypeError.
    d = ParisDataset("../VLoD/", 512, 512, None, [])
    for i in d:
        print(i)