f16 with size filtering
This commit is contained in:
parent
04bcc0f7de
commit
18a9202196
2 changed files with 12 additions and 4 deletions
|
@ -77,6 +77,7 @@ data:
|
|||
batch_size: 50 # TODO: max out
|
||||
num_workers: 4
|
||||
multinode: True
|
||||
min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
|
||||
train:
|
||||
shards: '{000000..231317}.tar -'
|
||||
shuffle: 10000
|
||||
|
@ -124,7 +125,6 @@ lightning:
|
|||
unconditional_guidance_label: [""]
|
||||
|
||||
trainer:
|
||||
replace_sampler_ddp: False # TODO: check this
|
||||
benchmark: True
|
||||
val_check_interval: 5000000 # really sorry
|
||||
num_sanity_val_steps: 0
|
||||
|
|
|
@ -104,7 +104,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
|
|||
|
||||
class WebDataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(self, tar_base, batch_size, train=None, validation=None,
|
||||
test=None, num_workers=4, multinode=True,
|
||||
test=None, num_workers=4, multinode=True, min_size=None,
|
||||
**kwargs):
|
||||
super().__init__(self)
|
||||
print(f'Setting tar base to {tar_base}')
|
||||
|
@ -115,6 +115,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
|||
self.validation = validation
|
||||
self.test = test
|
||||
self.multinode = multinode
|
||||
self.min_size = min_size # filter out very small images
|
||||
|
||||
def make_loader(self, dataset_config, train=True):
|
||||
if 'image_transforms' in dataset_config:
|
||||
|
@ -163,9 +164,17 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
|
|||
|
||||
return loader
|
||||
|
||||
def filter_size(self, x):
|
||||
if self.min_size is None:
|
||||
return True
|
||||
try:
|
||||
return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def filter_keys(self, x):
|
||||
try:
|
||||
return ("jpg" in x) and ("txt" in x)
|
||||
return ("jpg" in x) and ("txt" in x) and self.filter_size(x)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
@ -260,7 +269,6 @@ if __name__ == "__main__":
|
|||
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
|
||||
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
|
||||
|
||||
|
||||
config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
|
||||
datamod = WebDataModuleFromConfig(**config["data"]["params"])
|
||||
dataloader = datamod.train_dataloader()
|
||||
|
|
Loading…
Reference in a new issue