diff --git a/configs/stable-diffusion/dev_mn.yaml b/configs/stable-diffusion/dev_mn.yaml
index 7678b3e..3b3a0ef 100644
--- a/configs/stable-diffusion/dev_mn.yaml
+++ b/configs/stable-diffusion/dev_mn.yaml
@@ -81,7 +81,8 @@ data:
     num_workers: 4
     n_nodes: 4
     train:
-      shards: '{000000..231349}.tar -'
+      shards: '{000000..231339}.tar -'
+      shuffle: 10000
       image_key: jpg
       image_transforms:
       - target: torchvision.transforms.Resize
@@ -92,10 +93,9 @@ data:
         params:
           size: 256
-      shuffle: 0
-      n_examples: 100000

     validation:
-      shards: '{000011..000012}.tar -' # TODO: wild guess, change
+      shards: '{231340..231349}.tar -'
+      shuffle: 0
       image_key: jpg
       image_transforms:
       - target: torchvision.transforms.Resize
@@ -106,10 +106,6 @@ data:
         params:
           size: 256
-      shuffle: 0
-      n_examples: 60000 # TODO: find out
-
-


 lightning:
   callbacks:
diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 6689c78..bfc7d3c 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -12,6 +12,7 @@ from tqdm import tqdm
 from omegaconf import OmegaConf
 from einops import rearrange
 import torch
+from webdataset.handlers import warn_and_continue

 from ldm.util import instantiate_from_config

@@ -27,12 +28,10 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
     :rtype: dict
     """
     batched = {key: [] for key in samples[0]}
-    # assert isinstance(samples[0][first_key], (list, tuple)), type(samples[first_key])

     for s in samples:
         [batched[key].append(s[key]) for key in batched]

-
     result = {}
     for key in batched:
         if isinstance(batched[key][0], (int, float)):
@@ -40,21 +39,18 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
             result[key] = np.array(list(batched[key]))
         elif isinstance(batched[key][0], torch.Tensor):
             if combine_tensors:
-                # import torch
-
                 result[key] = torch.stack(list(batched[key]))
         elif isinstance(batched[key][0], np.ndarray):
             if combine_tensors:
                 result[key] = np.array(list(batched[key]))
         else:
             result[key] = list(batched[key])
-        # result.append(b)
     return result


 class WebDataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, tar_base, batch_size, train=None, validation=None,
-                 test=None, num_workers=4, load_ddp=True, n_nodes=1,
+                 test=None, num_workers=4, multinode=True,
                  **kwargs):
         super().__init__(self)
         print(f'Setting tar base to {tar_base}')
@@ -64,9 +60,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         self.train = train
         self.validation = validation
         self.test = test
-        self.load_ddp = load_ddp
-        self.multinode = n_nodes > 1
-        self.n_nodes = n_nodes  # n gpu ??
+        self.multinode = multinode

     def make_loader(self, dataset_config, train=True):
         if 'image_transforms' in dataset_config:
@@ -83,101 +77,47 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         else:
             transforms_config = dict()

-        transform_dict = {dkey: load_partial_from_config(transforms_config[dkey]) if transforms_config[
-            dkey] != 'identity' else identity
-                          for dkey in transforms_config}
+        transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
+                          if transforms_config[dkey] != 'identity' else identity
+                          for dkey in transforms_config}

         img_key = dataset_config.get('image_key', 'jpeg')
         transform_dict.update({img_key: image_transforms})

         shuffle = dataset_config.get('shuffle', 0)
+        shardshuffle = shuffle > 0

-        # TODO fid strategy when n exmples not known beforehand
-        n_examples = dataset_config.get('n_examples', 1e6) // self.n_nodes
-
-        shards_to_load = dataset_config.shards
-        dset_name = 'unknown'
-        if isinstance(shards_to_load, str):
-            print(f'Loading tars based on the string {shards_to_load}')
-            tars = os.path.join(self.tar_base, shards_to_load)
-            start_shard_id, end_shard_id = dataset_config.shards.split('{')[-1].split('}')[0].split('..')
-            n_shards = int(end_shard_id) - int(start_shard_id) + 1
-            dset_name = dataset_config.shards.split('-')[0]
-        elif isinstance(shards_to_load, int):
-            print(f'Creating tar list, max shard is {shards_to_load}')
-            try:
-                tars = [tf for tf in natsorted(glob(os.path.join(self.tar_base, '*.tar'))) if
-                        int(tf.split('/')[-1].split('.')[0]) < shards_to_load]
-                n_shards = len(tars)
-                random.shuffle(tars)
-
-            except ValueError as e:
-                print('tarfile names should follow the pattern .tar . Check names of the files')
-                raise e
-        else:
-            raise ValueError(
-                'shards should be either a string containing consecutive shards or an int defining the max shard number')
-
-        print(f'Got {n_shards} shard files in datafolder for {"training" if train else "validation"}')
-
-        # if self.num_workers > 0:
-        #     assert n_shards % self.num_workers == 0 , f'Number of workers which is {self.num_workers} does not evenly divide number of shards which is {n_shards}'
-        print(f'Loading webdataset based dataloader based on {n_shards} of {dset_name} dataset.')
-
-        # start creating the dataset
         nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
-        epoch_length = n_examples // (self.batch_size)
-        dset = wds.WebDataset(tars, nodesplitter=nodesplitter).shuffle(shuffle)
-
-        with_epoch_args = {'nsamples': n_examples, 'nbatches': epoch_length}
-
-        if 'filters' in dataset_config:
-            for stage in tqdm(dataset_config.filters,
-                              desc=f'Applying the following filters: {[f for f in dataset_config.filters]}'):
-                f = getattr(dset, stage)
-                dset = f(dset, *dataset_config.filters[stage].args,
-                         **dataset_config.filters[stage].get('kwargs', dict()))
-
-        print(f'Dataset holding {len(dset.pipeline[0].urls)} shards')
-
-        from webdataset.handlers import warn_and_continue
+        tars = os.path.join(self.tar_base, dataset_config.shards)
+        dset = wds.WebDataset(
+            tars,
+            nodesplitter=nodesplitter,
+            shardshuffle=shardshuffle).shuffle(shuffle)
+        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

         dset = (dset
                 .decode('pil', handler=warn_and_continue)
-                # .to_tuple("jpg;png;jpeg pickle cls hls")
-                # .map_tuple(image_transforms,load_partial_from_config(nns_transform) if 'target' in nns_transform else identity,identity,identity)
                 .map_dict(**transform_dict)
-                .repeat()
                 .batched(self.batch_size, partial=False,
                          collation_fn=dict_collation_fn)
-                .with_length(n_examples)
-                .with_epoch(**with_epoch_args)
                 )

         loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
                                num_workers=self.num_workers)

-        return loader, n_examples
+        return loader

     def train_dataloader(self):
-        assert self.train is not None
-        loader, dset_size = self.make_loader(self.train)
-        # if self.load_ddp:
-        #     loader = loader.ddp_equalize(dset_size // self.batch_size)
-        return loader
+        return self.make_loader(self.train)

     def val_dataloader(self):
-        assert self.train is not None
-        loader, _ = self.make_loader(self.validation, train=False)
-        return loader
+        return self.make_loader(self.validation, train=False)

     def test_dataloader(self):
-        assert self.train is not None
-        loader, _ = self.make_loader(self.test, train=False)
-        return loader
+        return self.make_loader(self.test, train=False)


-if __name__ == "__main__":
+def example00():
     url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
     dataset = wds.WebDataset(url)
     example = next(iter(dataset))
@@ -207,3 +147,48 @@ if __name__ == "__main__":
         print(ex["image"].size, ex["text"])
         if i >= 100:
             break
+
+
+def example01():
+    # the first laion shards contain ~10k examples each
+    url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"
+
+    batch_size = 3
+    shuffle_buffer = 10000
+    dset = wds.WebDataset(
+        url,
+        nodesplitter=wds.shardlists.split_by_node,
+        shardshuffle=True,
+    )
+    dset = (dset
+            .shuffle(shuffle_buffer, initial=shuffle_buffer)
+            .decode('pil', handler=warn_and_continue)
+            .batched(batch_size, partial=False,
+                     collation_fn=dict_collation_fn)
+            )
+
+    num_workers = 2
+    loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)
+
+    batch_sizes = list()
+    keys_per_epoch = list()
+    for epoch in range(5):
+        keys = list()
+        for batch in tqdm(loader):
+            batch_sizes.append(len(batch["__key__"]))
+            keys.append(batch["__key__"])
+
+        for bs in batch_sizes:
+            assert bs==batch_size
+        print(f"{len(batch_sizes)} batches of size {batch_size}.")
+        batch_sizes = list()
+
+        keys_per_epoch.append(keys)
+        for i_batch in [0, 1, -1]:
+            print(f"Batch {i_batch} of epoch {epoch}:")
+            print(keys[i_batch])
+        print("next epoch.")
+
+
+if __name__ == "__main__":
+    example01()