multinode hacks

This commit is contained in:
Patrick Esser 2022-07-14 21:29:46 +00:00 committed by pesser
parent 7f8c423450
commit a7aad82e51
4 changed files with 27 additions and 12 deletions

View File

@ -2,7 +2,6 @@ model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
@ -20,7 +19,7 @@ model:
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]

View File

@ -151,12 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
if self.tar_base == "__improvedaesthetic__":
print("## Warning, loading the same improved aesthetic dataset "
"for all splits and ignoring shards parameter.")
urls = []
for i in range(1, 65):
for j in range(512):
for k in range(5):
urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar')
tars = [f'pipe:aws s3 cp {url} -' for url in urls]
tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
else:
tars = os.path.join(self.tar_base, dataset_config.shards)
@ -314,7 +309,8 @@ if __name__ == "__main__":
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
#config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
#config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
datamod = WebDataModuleFromConfig(**config["data"]["params"])
dataloader = datamod.train_dataloader()

20
main.py
View File

@ -21,6 +21,9 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
MULTINODE_HACKS = True
def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
@ -268,6 +271,9 @@ class SetupCallback(Callback):
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
print("Project config")
print(OmegaConf.to_yaml(self.config))
if MULTINODE_HACKS:
import time
time.sleep(5)
OmegaConf.save(self.config,
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
@ -278,7 +284,7 @@ class SetupCallback(Callback):
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
@ -759,9 +765,19 @@ if __name__ == "__main__":
del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
if not "plugins" in trainer_kwargs:
trainer_kwargs["plugins"] = list()
if not lightning_config.get("find_unused_parameters", True):
from pytorch_lightning.plugins import DDPPlugin
trainer_kwargs["plugins"] = DDPPlugin(find_unused_parameters=False)
trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
if MULTINODE_HACKS:
# disable resume from hpc ckpts
# NOTE below only works in later versions
# from pytorch_lightning.plugins.environments import SLURMEnvironment
# trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
# hence we monkey patch things
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
setattr(CheckpointConnector, "hpc_resume_path", None)
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###

View File

@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
return model
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -258,3 +258,7 @@ if __name__ == "__main__":
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
f" \nEnjoy.")
if __name__ == "__main__":
main()