diff --git a/ldm/data/laion.py b/ldm/data/laion.py index d6b7522..053d583 100644 --- a/ldm/data/laion.py +++ b/ldm/data/laion.py @@ -154,6 +154,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule): dset = (dset .select(self.filter_keys) .decode('pil', handler=wds.warn_and_continue) + .select(self.filter_sizes) .map_dict(**transform_dict, handler=wds.warn_and_continue) .batched(self.batch_size, partial=False, collation_fn=dict_collation_fn) @@ -174,7 +175,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule): def filter_keys(self, x): try: - return ("jpg" in x) and ("txt" in x) and self.filter_size(x) + return ("jpg" in x) and ("txt" in x) except Exception: return False