f16 with size filtering

This commit is contained in:
rromb 2022-06-10 11:08:42 +02:00
parent 04bcc0f7de
commit 18a9202196
2 changed files with 12 additions and 4 deletions

View file

@ -77,6 +77,7 @@ data:
batch_size: 50 # TODO: max out batch_size: 50 # TODO: max out
num_workers: 4 num_workers: 4
multinode: True multinode: True
min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
train: train:
shards: '{000000..231317}.tar -' shards: '{000000..231317}.tar -'
shuffle: 10000 shuffle: 10000
@ -124,7 +125,6 @@ lightning:
unconditional_guidance_label: [""] unconditional_guidance_label: [""]
trainer: trainer:
replace_sampler_ddp: False # TODO: check this
benchmark: True benchmark: True
val_check_interval: 5000000 # really sorry val_check_interval: 5000000 # really sorry
num_sanity_val_steps: 0 num_sanity_val_steps: 0

View file

@ -104,7 +104,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
class WebDataModuleFromConfig(pl.LightningDataModule): class WebDataModuleFromConfig(pl.LightningDataModule):
def __init__(self, tar_base, batch_size, train=None, validation=None, def __init__(self, tar_base, batch_size, train=None, validation=None,
test=None, num_workers=4, multinode=True, test=None, num_workers=4, multinode=True, min_size=None,
**kwargs): **kwargs):
super().__init__(self) super().__init__(self)
print(f'Setting tar base to {tar_base}') print(f'Setting tar base to {tar_base}')
@ -115,6 +115,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
self.validation = validation self.validation = validation
self.test = test self.test = test
self.multinode = multinode self.multinode = multinode
self.min_size = min_size # filter out very small images
def make_loader(self, dataset_config, train=True): def make_loader(self, dataset_config, train=True):
if 'image_transforms' in dataset_config: if 'image_transforms' in dataset_config:
@ -163,9 +164,17 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
return loader return loader
def filter_size(self, x):
if self.min_size is None:
return True
try:
return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
except Exception:
return False
def filter_keys(self, x): def filter_keys(self, x):
try: try:
return ("jpg" in x) and ("txt" in x) return ("jpg" in x) and ("txt" in x) and self.filter_size(x)
except Exception: except Exception:
return False return False
@ -260,7 +269,6 @@ if __name__ == "__main__":
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml") config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
datamod = WebDataModuleFromConfig(**config["data"]["params"]) datamod = WebDataModuleFromConfig(**config["data"]["params"])
dataloader = datamod.train_dataloader() dataloader = datamod.train_dataloader()