diff --git a/ldm/data/laion.py b/ldm/data/laion.py index 588549e..73d928b 100644 --- a/ldm/data/laion.py +++ b/ldm/data/laion.py @@ -148,7 +148,18 @@ class WebDataModuleFromConfig(pl.LightningDataModule): nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only - tars = os.path.join(self.tar_base, dataset_config.shards) + if self.tar_base == "__improvedaesthetic__": + print("## Warning, loading the same improved aesthetic dataset " + "for all splits and ignoring shards parameter.") + urls = [] + for i in range(1, 65): + for j in range(512): + for k in range(5): + urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar') + tars = [f'pipe:aws s3 cp {url} -' for url in urls] + else: + tars = os.path.join(self.tar_base, dataset_config.shards) + dset = wds.WebDataset( tars, nodesplitter=nodesplitter,