diff --git a/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml
index d6420a9..28ba235 100644
--- a/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml
+++ b/configs/stable-diffusion/txt2img-multinode-clip-encoder-f16-256-pretraining.yaml
@@ -77,6 +77,7 @@ data:
     batch_size: 50 # TODO: max out
     num_workers: 4
     multinode: True
+    min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
     train:
       shards: '{000000..231317}.tar -'
       shuffle: 10000
@@ -124,7 +125,6 @@ lightning:
         unconditional_guidance_label: [""]
 
   trainer:
-    replace_sampler_ddp: False # TODO: check this
     benchmark: True
     val_check_interval: 5000000 # really sorry
     num_sanity_val_steps: 0
diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index b63c5b8..d6b7522 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -104,7 +104,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
 
 class WebDataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, tar_base, batch_size, train=None, validation=None,
-                 test=None, num_workers=4, multinode=True,
+                 test=None, num_workers=4, multinode=True, min_size=None,
                  **kwargs):
         super().__init__(self)
         print(f'Setting tar base to {tar_base}')
@@ -115,6 +115,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         self.validation = validation
         self.test = test
         self.multinode = multinode
+        self.min_size = min_size  # filter out very small images
 
     def make_loader(self, dataset_config, train=True):
         if 'image_transforms' in dataset_config:
@@ -163,9 +164,17 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
 
         return loader
 
+    def filter_size(self, x):
+        if self.min_size is None:
+            return True
+        try:
+            return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
+        except Exception:
+            return False
+
     def filter_keys(self, x):
         try:
-            return ("jpg" in x) and ("txt" in x)
+            return ("jpg" in x) and ("txt" in x) and self.filter_size(x)
         except Exception:
             return False
 
@@ -260,7 +269,6 @@ if __name__ == "__main__":
 
     from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 
-    config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()
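
The change above adds a `min_size` option that drops webdataset samples whose recorded original dimensions fall below a threshold, wired in through `filter_keys` before any decoding happens. Below is a minimal standalone sketch of the same filtering idea, assuming webdataset's fluid `.select()` pipeline API; the shard path and the `make_size_filter` helper are hypothetical, for illustration only. One wrinkle the sketch handles explicitly: if the predicate runs before a `.decode()` stage, the `json` field is still raw bytes and must be parsed first.

```python
import json

import webdataset as wds


def make_size_filter(min_size):
    """Return a predicate that keeps only samples whose original
    width and height (from the per-sample json metadata) are both
    at least min_size. Hypothetical helper for illustration."""
    def _filter(sample):
        if min_size is None:
            return True
        try:
            meta = sample["json"]
            if isinstance(meta, bytes):
                # Before a .decode() stage, webdataset yields raw bytes.
                meta = json.loads(meta)
            return (meta["original_width"] >= min_size
                    and meta["original_height"] >= min_size)
        except Exception:
            # Missing or malformed metadata: drop the sample.
            return False
    return _filter


# Hypothetical shard path, mirroring the brace notation used in the config.
dataset = (
    wds.WebDataset("laion/{000000..000009}.tar")
    .select(make_size_filter(256))  # filter on metadata before decoding
    .decode("pil", handler=wds.warn_and_continue)
)
```

In the repo itself the same logic is reached through the data module: the new `data.params.min_size: 256` config key is forwarded into `WebDataModuleFromConfig(**config["data"]["params"])`, as shown in the `__main__` block of the diff.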