filter examples without req keys, make collation more robust

This commit is contained in:
Patrick Esser 2022-06-01 22:04:45 +00:00 committed by root
parent a95f78f056
commit 7d432123d5

View file

@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
:returns: single sample consisting of a batch :returns: single sample consisting of a batch
:rtype: dict :rtype: dict
""" """
batched = {key: [] for key in samples[0]} keys = set.intersection(*[set(sample.keys()) for sample in samples])
batched = {key: [] for key in keys}
for s in samples: for s in samples:
[batched[key].append(s[key]) for key in batched] [batched[key].append(s[key]) for key in batched]
@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.') print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
dset = (dset dset = (dset
.select(self.filter_keys)
.decode('pil', handler=wds.warn_and_continue) .decode('pil', handler=wds.warn_and_continue)
.map_dict(**transform_dict, handler=wds.warn_and_continue) .map_dict(**transform_dict, handler=wds.warn_and_continue)
.batched(self.batch_size, partial=False, .batched(self.batch_size, partial=False,
@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
return loader return loader
def filter_keys(self, x):
try:
return ("jpg" in x) and ("txt" in x)
except Exception:
return False
def train_dataloader(self): def train_dataloader(self):
return self.make_loader(self.train) return self.make_loader(self.train)