f1293f9795
* simple datasets * add conversion script * finish fine tune example * update readme * update readme
101 lines
No EOL
3.6 KiB
Python
101 lines
No EOL
3.6 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
from pathlib import Path
|
|
import json
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from einops import rearrange
|
|
from ldm.util import instantiate_from_config
|
|
from datasets import load_dataset
|
|
|
|
class FolderData(Dataset):
|
|
def __init__(self, root_dir, caption_file, image_transforms, ext="jpg") -> None:
|
|
self.root_dir = Path(root_dir)
|
|
with open(caption_file, "rt") as f:
|
|
captions = json.load(f)
|
|
self.captions = captions
|
|
|
|
self.paths = list(self.root_dir.rglob(f"*.{ext}"))
|
|
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
|
|
image_transforms.extend([transforms.ToTensor(),
|
|
transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
|
|
image_transforms = transforms.Compose(image_transforms)
|
|
self.tform = image_transforms
|
|
|
|
# assert all(['full/' + str(x.name) in self.captions for x in self.paths])
|
|
|
|
def __len__(self):
|
|
return len(self.captions.keys())
|
|
|
|
def __getitem__(self, index):
|
|
chosen = list(self.captions.keys())[index]
|
|
im = Image.open(self.root_dir/chosen)
|
|
im = self.process_im(im)
|
|
caption = self.captions[chosen]
|
|
if caption is None:
|
|
caption = "old book illustration"
|
|
return {"jpg": im, "txt": caption}
|
|
|
|
def process_im(self, im):
|
|
im = im.convert("RGB")
|
|
return self.tform(im)
|
|
|
|
def hf_dataset(
|
|
name,
|
|
image_transforms=[],
|
|
image_column="image",
|
|
text_column="text",
|
|
split='train',
|
|
image_key='image',
|
|
caption_key='txt',
|
|
):
|
|
"""Make huggingface dataset with appropriate list of transforms applied
|
|
"""
|
|
ds = load_dataset(name, split=split)
|
|
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
|
|
image_transforms.extend([transforms.ToTensor(),
|
|
transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
|
|
tform = transforms.Compose(image_transforms)
|
|
|
|
assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
|
|
assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
|
|
|
|
def pre_process(examples):
|
|
processed = {}
|
|
processed[image_key] = [tform(im) for im in examples[image_column]]
|
|
processed[caption_key] = examples[text_column]
|
|
return processed
|
|
|
|
ds.set_transform(pre_process)
|
|
return ds
|
|
|
|
class TextOnly(Dataset):
|
|
def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
|
|
"""Returns only captions with dummy images"""
|
|
self.output_size = output_size
|
|
self.image_key = image_key
|
|
self.caption_key = caption_key
|
|
if isinstance(captions, Path):
|
|
self.captions = self._load_caption_file(captions)
|
|
else:
|
|
self.captions = captions
|
|
|
|
if n_gpus > 1:
|
|
# hack to make sure that all the captions appear on each gpu
|
|
repeated = [n_gpus*[x] for x in self.captions]
|
|
self.captions = []
|
|
[self.captions.extend(x) for x in repeated]
|
|
|
|
def __len__(self):
|
|
return len(self.captions)
|
|
|
|
def __getitem__(self, index):
|
|
dummy_im = torch.zeros(3, self.output_size, self.output_size)
|
|
dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
|
|
return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
|
|
|
|
def _load_caption_file(self, filename):
|
|
with open(filename, 'rt') as f:
|
|
captions = f.readlines()
|
|
return [x.strip('\n') for x in captions] |