filter examples without req keys, make collation more robust

This commit is contained in:
Patrick Esser 2022-06-01 22:04:45 +00:00 committed by root
parent a95f78f056
commit 7d432123d5
1 changed files with 9 additions and 1 deletions

View File

@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
:returns: single sample consisting of a batch
:rtype: dict
"""
batched = {key: [] for key in samples[0]}
keys = set.intersection(*[set(sample.keys()) for sample in samples])
batched = {key: [] for key in keys}
for s in samples:
[batched[key].append(s[key]) for key in batched]
@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
dset = (dset
.select(self.filter_keys)
.decode('pil', handler=wds.warn_and_continue)
.map_dict(**transform_dict, handler=wds.warn_and_continue)
.batched(self.batch_size, partial=False,
@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
return loader
def filter_keys(self, x):
try:
return ("jpg" in x) and ("txt" in x)
except Exception:
return False
def train_dataloader(self):
return self.make_loader(self.train)