filter examples without req keys, make collation more robust
This commit is contained in:
parent
a95f78f056
commit
7d432123d5
1 changed files with 9 additions and 1 deletions
|
@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
|||
:returns: single sample consisting of a batch
|
||||
:rtype: dict
|
||||
"""
|
||||
batched = {key: [] for key in samples[0]}
|
||||
keys = set.intersection(*[set(sample.keys()) for sample in samples])
|
||||
batched = {key: [] for key in keys}
|
||||
|
||||
for s in samples:
|
||||
[batched[key].append(s[key]) for key in batched]
|
||||
|
@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
|||
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
|
||||
|
||||
dset = (dset
|
||||
.select(self.filter_keys)
|
||||
.decode('pil', handler=wds.warn_and_continue)
|
||||
.map_dict(**transform_dict, handler=wds.warn_and_continue)
|
||||
.batched(self.batch_size, partial=False,
|
||||
|
@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
|||
|
||||
return loader
|
||||
|
||||
def filter_keys(self, x):
|
||||
try:
|
||||
return ("jpg" in x) and ("txt" in x)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def train_dataloader(self):
|
||||
return self.make_loader(self.train)
|
||||
|
||||
|
|
Loading…
Reference in a new issue