filter examples without req keys, make collation more robust
This commit is contained in:
parent
a95f78f056
commit
7d432123d5
1 changed files with 9 additions and 1 deletions
|
@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
||||||
:returns: single sample consisting of a batch
|
:returns: single sample consisting of a batch
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
batched = {key: [] for key in samples[0]}
|
keys = set.intersection(*[set(sample.keys()) for sample in samples])
|
||||||
|
batched = {key: [] for key in keys}
|
||||||
|
|
||||||
for s in samples:
|
for s in samples:
|
||||||
[batched[key].append(s[key]) for key in batched]
|
[batched[key].append(s[key]) for key in batched]
|
||||||
|
@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
||||||
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
|
print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
|
||||||
|
|
||||||
dset = (dset
|
dset = (dset
|
||||||
|
.select(self.filter_keys)
|
||||||
.decode('pil', handler=wds.warn_and_continue)
|
.decode('pil', handler=wds.warn_and_continue)
|
||||||
.map_dict(**transform_dict, handler=wds.warn_and_continue)
|
.map_dict(**transform_dict, handler=wds.warn_and_continue)
|
||||||
.batched(self.batch_size, partial=False,
|
.batched(self.batch_size, partial=False,
|
||||||
|
@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
def filter_keys(self, x):
|
||||||
|
try:
|
||||||
|
return ("jpg" in x) and ("txt" in x)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
return self.make_loader(self.train)
|
return self.make_loader(self.train)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue