diff --git a/ldm/data/laion.py b/ldm/data/laion.py index f23ebb1..588549e 100644 --- a/ldm/data/laion.py +++ b/ldm/data/laion.py @@ -153,7 +153,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule): tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle, - handler=wds.warn_and_continue).shuffle(shuffle) + handler=wds.warn_and_continue).repeat().shuffle(shuffle) print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.') dset = (dset