diff --git a/environment.yaml b/environment.yaml index f36b0e1..3a8e42e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -22,6 +22,7 @@ dependencies: - einops==0.3.0 - torch-fidelity==0.3.0 - transformers==4.3.1 + - webdataset==0.2.5 - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/openai/CLIP.git@main#egg=clip - - -e . \ No newline at end of file + - -e . diff --git a/ldm/data/laion.py b/ldm/data/laion.py new file mode 100644 index 0000000..5d87d5d --- /dev/null +++ b/ldm/data/laion.py @@ -0,0 +1,36 @@ +import webdataset as wds +from PIL import Image +import io +import os +from tqdm import tqdm + +if __name__ == "__main__": + url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -" + dataset = wds.WebDataset(url) + example = next(iter(dataset)) + for k in example: + print(k, type(example[k])) + + print(example["__key__"]) + for k in ["json", "txt"]: + print(example[k].decode()) + + image = Image.open(io.BytesIO(example["jpg"])) + outdir = "tmp" + os.makedirs(outdir, exist_ok=True) + image.save(os.path.join(outdir, example["__key__"]+".png")) + + + def load_example(example): + return { + "key": example["__key__"], + "image": Image.open(io.BytesIO(example["jpg"])), + "text": example["txt"].decode(), + } + + + for i, example in tqdm(enumerate(dataset)): + ex = load_example(example) + print(ex["image"].size, ex["text"]) + if i >= 100: + break