filter examples without req keys, make collation more robust
This commit is contained in:
		
							parent
							
								
									a95f78f056
								
							
						
					
					
						commit
						7d432123d5
					
				
					 1 changed files with 9 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -80,7 +80,8 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
 | 
			
		|||
    :returns: single sample consisting of a batch
 | 
			
		||||
    :rtype: dict
 | 
			
		||||
    """
 | 
			
		||||
    batched = {key: [] for key in samples[0]}
 | 
			
		||||
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
 | 
			
		||||
    batched = {key: [] for key in keys}
 | 
			
		||||
 | 
			
		||||
    for s in samples:
 | 
			
		||||
        [batched[key].append(s[key]) for key in batched]
 | 
			
		||||
| 
						 | 
				
			
			@ -150,6 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
 | 
			
		|||
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
 | 
			
		||||
 | 
			
		||||
        dset = (dset
 | 
			
		||||
                .select(self.filter_keys)
 | 
			
		||||
                .decode('pil', handler=wds.warn_and_continue)
 | 
			
		||||
                .map_dict(**transform_dict, handler=wds.warn_and_continue)
 | 
			
		||||
                .batched(self.batch_size, partial=False,
 | 
			
		||||
| 
						 | 
				
			
			@ -161,6 +163,12 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
 | 
			
		|||
 | 
			
		||||
        return loader
 | 
			
		||||
 | 
			
		||||
    def filter_keys(self, x):
 | 
			
		||||
        try:
 | 
			
		||||
            return ("jpg" in x) and ("txt" in x)
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
    def train_dataloader(self):
 | 
			
		||||
        return self.make_loader(self.train)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue