From d9c9747122010c1d9214eb15b1f564cf93d14d0a Mon Sep 17 00:00:00 2001
From: Patrick Esser
Date: Fri, 22 Jul 2022 09:50:01 +0000
Subject: [PATCH] laion explorations

---
Review notes (not part of the commit message; ignored by `git am`):

* Hunk 1 turns the former `__main__` body into `example02()`.
* `example03()` streams shards 000000..060207 and, every 1000 samples that
  carry both "jpg" and "txt", prints the share whose json metadata reports
  original_width and original_height >= 512. The totals are only printed
  inside the loop, so the final counts are not reported when iteration ends.
* `example04()` walks the shards in reverse order and tries to read the first
  sample of each, printing "Error @ <shard>" on failure. In that function
  `filter_size` is defined but never called, `example` is assigned but never
  read, both filter closures are re-created on every loop iteration, and the
  `except` clause discards the exception, so the cause of a failing shard is
  not shown.
* Both new functions use `wds`, which this diff does not import; presumably
  it is imported at the top of ldm/data/laion.py -- confirm.
* Leading whitespace of the context lines was reconstructed; verify against
  the target file before applying.

 ldm/data/laion.py | 70 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 41d6fa3..07c948e 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -300,8 +300,7 @@ def example01():
         print("next epoch.")
 
 
-if __name__ == "__main__":
-    #example01()
+def example02():
     from omegaconf import OmegaConf
     from torch.utils.data.distributed import DistributedSampler
     from torch.utils.data import IterableDataset
@@ -318,3 +317,70 @@ if __name__ == "__main__":
         print(batch.keys())
         print(batch["jpg"].shape)
         break
+
+
+def example03():
+    # improved aesthetics
+    tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
+    dataset = wds.WebDataset(tars)
+
+    def filter_keys(x):
+        try:
+            return ("jpg" in x) and ("txt" in x)
+        except Exception:
+            return False
+
+    def filter_size(x):
+        try:
+            return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
+        except Exception:
+            return False
+
+    dataset = (dataset
+               .select(filter_keys)
+               .decode('pil', handler=wds.warn_and_continue))
+    n_total = 0
+    n_large = 0
+    for i, example in enumerate(dataset):
+        n_total += 1
+        if filter_size(example):
+            n_large += 1
+
+        if i%1000 == 0:
+            print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
+
+
+
+def example04():
+    # improved aesthetics
+    for i_shard in range(60208)[::-1]:
+        print(i_shard)
+        tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
+        dataset = wds.WebDataset(tars)
+
+        def filter_keys(x):
+            try:
+                return ("jpg" in x) and ("txt" in x)
+            except Exception:
+                return False
+
+        def filter_size(x):
+            try:
+                return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
+            except Exception:
+                return False
+
+        dataset = (dataset
+                   .select(filter_keys)
+                   .decode('pil', handler=wds.warn_and_continue))
+        try:
+            example = next(iter(dataset))
+        except Exception:
+            print(f"Error @ {i_shard}")
+
+
+if __name__ == "__main__":
+    #example01()
+    #example02()
+    #example03()
+    example04()