f16 with size filtering

commit 18a9202196 (parent 04bcc0f7de)
2 changed files with 12 additions and 4 deletions
@@ -77,6 +77,7 @@ data:
     batch_size: 50 # TODO: max out
     num_workers: 4
     multinode: True
+    min_size: 256 # TODO: experiment. Note: for 2B, images are stored at max 384 resolution
     train:
       shards: '{000000..231317}.tar -'
       shuffle: 10000
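The new min_size key needs no extra plumbing: everything under data.params is splatted into WebDataModuleFromConfig (see the Python hunks below), whose signature now accepts it explicitly. A hedged sketch of the call this config amounts to — tar_base and the remaining params are not shown in this hunk, so they are elided:

datamod = WebDataModuleFromConfig(
    tar_base=...,  # from the config; not part of this hunk
    batch_size=50,
    num_workers=4,
    multinode=True,
    min_size=256,  # new in this commit: reject images below 256 px on either side
    train={'shards': '{000000..231317}.tar -', 'shuffle': 10000},
)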
@@ -124,7 +125,6 @@ lightning:
     unconditional_guidance_label: [""]

   trainer:
-    replace_sampler_ddp: False # TODO: check this
     benchmark: True
     val_check_interval: 5000000 # really sorry
     num_sanity_val_steps: 0
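Dropping replace_sampler_ddp falls back to Lightning's default (True). Since the data here comes from an iterable WebDataset pipeline, there is no map-style sampler for Lightning to replace, so the override was likely inert anyway. A sketch of the trainer the remaining flags describe, assuming PyTorch Lightning 1.x (consistent with the pytorch_lightning.trainer.supporters import further down):

import pytorch_lightning as pl

trainer = pl.Trainer(
    benchmark=True,              # enable cudnn.benchmark for fixed-size batches
    val_check_interval=5000000,  # effectively disables mid-epoch validation
    num_sanity_val_steps=0,      # skip the pre-training sanity validation pass
)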
@@ -104,7 +104,7 @@ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):

 class WebDataModuleFromConfig(pl.LightningDataModule):
     def __init__(self, tar_base, batch_size, train=None, validation=None,
-                 test=None, num_workers=4, multinode=True,
+                 test=None, num_workers=4, multinode=True, min_size=None,
                  **kwargs):
         super().__init__(self)
         print(f'Setting tar base to {tar_base}')
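Because min_size defaults to None and the signature ends in **kwargs, configs that predate this commit construct unchanged, and extra keys are swallowed rather than rejected — a hedged illustration (argument values are placeholders):

# Both construct fine:
WebDataModuleFromConfig(tar_base='.', batch_size=1)                  # old config: min_size stays None
WebDataModuleFromConfig(tar_base='.', batch_size=1, legacy_key=123)  # unknown keys land in **kwargs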
@@ -115,6 +115,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         self.validation = validation
         self.test = test
         self.multinode = multinode
+        self.min_size = min_size # filter out very small images

     def make_loader(self, dataset_config, train=True):
         if 'image_transforms' in dataset_config:
@@ -163,9 +164,17 @@ class WebDataModuleFromConfig(pl.LightningDataModule):

         return loader

+    def filter_size(self, x):
+        if self.min_size is None:
+            return True
+        try:
+            return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
+        except Exception:
+            return False
+
     def filter_keys(self, x):
         try:
-            return ("jpg" in x) and ("txt" in x)
+            return ("jpg" in x) and ("txt" in x) and self.filter_size(x)
         except Exception:
             return False

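Two things worth noting about the new predicate. First, its behavior on a sample dict (the {'json': {'original_width': ..., 'original_height': ...}} shape is taken from filter_size itself): undersized and malformed samples are rejected rather than crashing the worker, and min_size=None disables filtering entirely. Second, the wiring inside make_loader is outside this diff, so the pipeline below is an assumption: in a typical WebDataset setup the predicate runs via .select(), and it has to come after .decode(), which parses the .json sidecar into a dict — on raw bytes, x['json']['original_width'] raises, the except swallows it, and every sample would be dropped. A hedged sketch, assuming the class from this file is in scope:

import webdataset as wds

dm = WebDataModuleFromConfig(tar_base='.', batch_size=1, min_size=256)
assert dm.filter_size({'json': {'original_width': 512, 'original_height': 384}})
assert not dm.filter_size({'json': {'original_width': 128, 'original_height': 384}})
assert not dm.filter_size({})  # missing metadata -> rejected, not raised

# Assumed pipeline wiring (make_loader's body is not shown in this commit):
dataset = (
    wds.WebDataset('{000000..231317}.tar')
    .decode('rgb')            # also parses the .json sidecar into a dict
    .select(dm.filter_keys)   # jpg+txt present and image large enough
    .to_tuple('jpg', 'txt')
)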
@@ -260,7 +269,6 @@ if __name__ == "__main__":
     from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

-
     config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()
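Since min_size is an ordinary constructor kwarg, it can also be overridden ad hoc without touching the YAML — a small hedged variation on the __main__ block above:

params = dict(config['data']['params'])
params['min_size'] = 512  # tighter filter than the config's 256
datamod = WebDataModuleFromConfig(**params)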