multinode hacks
parent 7f8c423450 · commit a7aad82e51
4 changed files with 27 additions and 12 deletions
@@ -2,7 +2,6 @@ model:
   base_learning_rate: 1.0e-04
   target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
-    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1

@@ -20,7 +19,7 @@ model:
     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]

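The warm_up_steps change only matters when resuming: the scheduler ramps the LR multiplier from f_start to f_max over warm_up_steps, so a resumed run left at [ 10000 ] would re-run the ramp from near zero. A minimal sketch of that behaviour, assuming a linear ramp as described by the config comments (the helper below is illustrative, not the repo's scheduler code):

# Illustrative linear warm-up multiplier, assuming the ramp the config comments describe.
def warmup_multiplier(step: int, warm_up_steps: int, f_start: float = 1e-6, f_max: float = 1.0) -> float:
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    return f_max

# With warm_up_steps=10000 a resumed run would see a tiny multiplier again;
# with warm_up_steps=1 the multiplier is back at f_max from the first step.
print(warmup_multiplier(0, 10000))                       # ~1e-6
print(warmup_multiplier(0, 1), warmup_multiplier(1, 1))  # ~1e-6, 1.0
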
@@ -151,12 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         if self.tar_base == "__improvedaesthetic__":
             print("## Warning, loading the same improved aesthetic dataset "
                   "for all splits and ignoring shards parameter.")
-            urls = []
-            for i in range(1, 65):
-                for j in range(512):
-                    for k in range(5):
-                        urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar')
-            tars = [f'pipe:aws s3 cp {url} -' for url in urls]
+            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
         else:
             tars = os.path.join(self.tar_base, dataset_config.shards)

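The single pipe: spec that replaces the nested URL loops relies on shard brace expansion. A minimal sketch of how such a spec is typically consumed, assuming the webdataset package (the handler choice and the inspection loop are illustrative and not part of this commit):

import webdataset as wds

# One {000000..060207} brace range stands in for the old triple loop that
# appended every aesthetics/{i}/{j}/{k}.tar URL by hand.
tars = ("pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/"
        "aesthetics_tars/{000000..060207}.tar -")
dataset = wds.WebDataset(tars, handler=wds.warn_and_continue)
for sample in dataset:
    print(sorted(sample.keys()))  # requires aws CLI and credentials to actually stream
    break
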
@@ -314,7 +309,8 @@ if __name__ == "__main__":
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

     #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
-    config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()

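The __main__ block above is a quick smoke test for the data module. A hedged variant, run from the same module so WebDataModuleFromConfig is in scope, that also inspects one batch (the shape-printing is an addition for illustration; the config path and keys come from the hunk above):

from omegaconf import OmegaConf

config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
datamod = WebDataModuleFromConfig(**config["data"]["params"])
dataloader = datamod.train_dataloader()

# Pull a single batch and report what the pipeline yields (tensor shapes or raw types).
batch = next(iter(dataloader))
print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})
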
main.py (20 lines changed)
@@ -21,6 +21,9 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config


+MULTINODE_HACKS = True
+
+
 def get_parser(**parser_kwargs):
     def str2bool(v):
         if isinstance(v, bool):

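MULTINODE_HACKS is a hard-coded module-level switch. A small sketch of an equivalent environment-driven variant, in case toggling it without editing main.py is useful (the environment variable name is an assumption, not something this commit introduces):

import os

# Defaults to the commit's behaviour (enabled); export MULTINODE_HACKS=0 to switch the hacks off.
MULTINODE_HACKS = os.environ.get("MULTINODE_HACKS", "1").lower() not in ("0", "false", "")
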
@@ -268,6 +271,9 @@ class SetupCallback(Callback):
                 os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
+            if MULTINODE_HACKS:
+                import time
+                time.sleep(5)
             OmegaConf.save(self.config,
                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

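The time.sleep(5) before OmegaConf.save is a crude way to paper over several ranks touching the config directory at roughly the same time. A hedged alternative sketch that writes only from global rank zero instead (the rank lookup reads common launcher variables and is illustrative, not the repo's code):

import os
from omegaconf import OmegaConf

def is_global_zero() -> bool:
    # Works for torchrun/SLURM-style launchers; defaults to True for single-process runs.
    return int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", "0"))) == 0

def save_project_config(config, cfgdir: str, now: str) -> None:
    if is_global_zero():
        OmegaConf.save(config, os.path.join(cfgdir, "{}-project.yaml".format(now)))
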
@@ -278,7 +284,7 @@ class SetupCallback(Callback):

         else:
             # ModelCheckpoint callback created log directory --- remove it
-            if not self.resume and os.path.exists(self.logdir):
+            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                 dst, name = os.path.split(self.logdir)
                 dst = os.path.join(dst, "child_runs", name)
                 os.makedirs(os.path.split(dst)[0], exist_ok=True)

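For context on the branch being guarded above: when a run is not resumed but ModelCheckpoint has already created the log directory, the callback relocates that directory under child_runs/. A standalone sketch of that move, so the new MULTINODE_HACKS guard is easier to read (shutil.move is used here for illustration; the callback's own code may differ):

import os
import shutil

def move_to_child_runs(logdir: str) -> None:
    # e.g. logs/2022-06-09T.../ -> logs/child_runs/2022-06-09T.../
    dst, name = os.path.split(logdir)
    dst = os.path.join(dst, "child_runs", name)
    os.makedirs(os.path.split(dst)[0], exist_ok=True)
    shutil.move(logdir, dst)

With MULTINODE_HACKS enabled the guard skips this branch entirely, so multiple ranks cannot race on moving the same directory.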
@@ -759,9 +765,19 @@ if __name__ == "__main__":
                 del callbacks_cfg['ignore_keys_callback']

         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+        if not "plugins" in trainer_kwargs:
+            trainer_kwargs["plugins"] = list()
         if not lightning_config.get("find_unused_parameters", True):
             from pytorch_lightning.plugins import DDPPlugin
-            trainer_kwargs["plugins"] = DDPPlugin(find_unused_parameters=False)
+            trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        if MULTINODE_HACKS:
+            # disable resume from hpc ckpts
+            # NOTE below only works in later versions
+            # from pytorch_lightning.plugins.environments import SLURMEnvironment
+            # trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
+            # hence we monkey patch things
+            from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
+            setattr(CheckpointConnector, "hpc_resume_path", None)

         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###

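The commented-out lines above reference the cleaner route available on newer PyTorch Lightning releases. A hedged sketch of that route, assuming a Lightning version whose SLURMEnvironment accepts auto_requeue (roughly 1.6 and later); trainer_kwargs is the dict built in the hunk above:

from pytorch_lightning.plugins.environments import SLURMEnvironment

trainer_kwargs.setdefault("plugins", [])
# Disables resuming from HPC checkpoints without monkey patching CheckpointConnector.
trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
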
@@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model


-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser()

     parser.add_argument(

@@ -258,3 +258,7 @@ if __name__ == "__main__":
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+    main()

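Wrapping the script body in main() means the file can now be imported without immediately running sampling. A hypothetical usage sketch of that property (the module name txt2img is an assumption; the script's actual path is not shown in this diff):

# Nothing runs at import time any more; sampling only starts when main() is called.
import txt2img

if __name__ == "__main__":
    txt2img.main()  # argparse still reads sys.argv, exactly as the old __main__ block did
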