laion webdataset example code

This commit is contained in:
pesser 2022-05-26 21:51:29 +00:00
parent 5a6571e384
commit f7a6152022
2 changed files with 38 additions and 1 deletions

View File

@ -22,6 +22,7 @@ dependencies:
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.3.1
- webdataset==0.2.5
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e .
- -e .

36
ldm/data/laion.py Normal file
View File

@ -0,0 +1,36 @@
import webdataset as wds
from PIL import Image
import io
import os
from tqdm import tqdm
if __name__ == "__main__":
url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
dataset = wds.WebDataset(url)
example = next(iter(dataset))
for k in example:
print(k, type(example[k]))
print(example["__key__"])
for k in ["json", "txt"]:
print(example[k].decode())
image = Image.open(io.BytesIO(example["jpg"]))
outdir = "tmp"
os.makedirs(outdir, exist_ok=True)
image.save(os.path.join(outdir, example["__key__"]+".png"))
def load_example(example):
return {
"key": example["__key__"],
"image": Image.open(io.BytesIO(example["jpg"])),
"text": example["txt"].decode(),
}
for i, example in tqdm(enumerate(dataset)):
ex = load_example(example)
print(ex["image"].size, ex["text"])
if i >= 100:
break