diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index ddba0bb..d80434b 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
     :returns: single sample consisting of a batch
     :rtype: dict
     """
-    batched = {key: [] for key in samples[0]}
+    keys = set.intersection(*[set(sample.keys()) for sample in samples])
+    batched = {key: [] for key in keys}
 
     for s in samples:
         [batched[key].append(s[key]) for key in batched]
@@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
 
         dset = (dset
+                .select(self.filter_keys)
                 .decode('pil', handler=wds.warn_and_continue)
                 .map_dict(**transform_dict, handler=wds.warn_and_continue)
                 .batched(self.batch_size, partial=False,
@@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
 
         return loader
 
+    def filter_keys(self, x):
+        try:
+            return ("jpg" in x) and ("txt" in x)
+        except Exception:
+            return False
+
     def train_dataloader(self):
         return self.make_loader(self.train)
 