multinode hacks
parent 7f8c423450
commit a7aad82e51
4 changed files with 27 additions and 12 deletions
@@ -2,7 +2,6 @@ model:
   base_learning_rate: 1.0e-04
   target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
-    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -20,7 +19,7 @@ model:
     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]
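The warm_up_steps change above matters because LambdaLinearScheduler multiplies base_learning_rate by a per-step factor that ramps linearly from f_start to f_max over warm_up_steps; dropping the warm-up to a single step lets a resumed run return to the full rate immediately instead of replaying the 10000-step ramp. A minimal sketch of that ramp, assuming only the linear warm-up behaviour implied by the config comments (not the full ldm.lr_scheduler implementation, which also handles cycles and decay):

def warmup_factor(step, warm_up_steps=1, f_start=1e-6, f_max=1.0):
    # linear ramp from f_start to f_max over warm_up_steps, then hold at f_max
    if step < warm_up_steps:
        return f_start + (f_max - f_start) * step / warm_up_steps
    return f_max

base_learning_rate = 1.0e-4
print(base_learning_rate * warmup_factor(5000, warm_up_steps=10000))  # ~5e-5: halfway up a fresh ramp
print(base_learning_rate * warmup_factor(1, warm_up_steps=1))          # full 1e-4 from step 1 onwards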
@@ -151,12 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         if self.tar_base == "__improvedaesthetic__":
             print("## Warning, loading the same improved aesthetic dataset "
                   "for all splits and ignoring shards parameter.")
-            urls = []
-            for i in range(1, 65):
-                for j in range(512):
-                    for k in range(5):
-                        urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar')
-            tars = [f'pipe:aws s3 cp {url} -' for url in urls]
+            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
         else:
             tars = os.path.join(self.tar_base, dataset_config.shards)
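The replacement swaps the roughly 164k individually enumerated S3 keys (the 64 x 512 x 5 nested loops) for a single shard spec: brace notation ({000000..060207}.tar) names every shard in what appears to be a repacked tar set, and the pipe: prefix streams each one through `aws s3 cp ... -` instead of downloading it first. A rough sketch of how such a spec is consumed, assuming the standard webdataset package that WebDataModuleFromConfig appears to build on (the shuffle/decode settings here are illustrative, not taken from this repo):

import webdataset as wds

# brace expansion yields one URL per shard; "pipe:" runs the command and reads its stdout,
# so each tar is streamed straight from S3 rather than listed and fetched up front
tars = ("pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/"
        "aesthetics_tars/{000000..060207}.tar -")
dataset = wds.WebDataset(tars).shuffle(1000).decode("rgb")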
@@ -314,7 +309,8 @@ if __name__ == "__main__":
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

     #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
-    config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()
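The edited `__main__` block is a local smoke test for the data module. One way to extend it is to pull a single batch and look at what the pipeline actually yields; the keys depend entirely on the dataset config, so treat this as a hypothetical probe rather than part of the module:

batch = next(iter(dataloader))
if isinstance(batch, dict):
    # e.g. image tensors and caption strings, keyed per the config's postprocessing
    print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})
else:
    print(type(batch))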
main.py (20 changed lines)
@@ -21,6 +21,9 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config


+MULTINODE_HACKS = True
+
+
 def get_parser(**parser_kwargs):
     def str2bool(v):
         if isinstance(v, bool):
@@ -268,6 +271,9 @@ class SetupCallback(Callback):
                     os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
+            if MULTINODE_HACKS:
+                import time
+                time.sleep(5)
             OmegaConf.save(self.config,
                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
@@ -278,7 +284,7 @@ class SetupCallback(Callback):

         else:
             # ModelCheckpoint callback created log directory --- remove it
-            if not self.resume and os.path.exists(self.logdir):
+            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                 dst, name = os.path.split(self.logdir)
                 dst = os.path.join(dst, "child_runs", name)
                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
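Both SetupCallback changes guard a log directory shared across nodes: the five-second sleep before OmegaConf.save presumably gives the freshly created cfgdir time to become visible on every node's view of the shared filesystem, and skipping the child_runs move when MULTINODE_HACKS is set keeps other ranks from relocating a logdir that rank 0 is still writing into. Neither rationale is stated in the commit; both are inferred from the fact that the MULTINODE_HACKS flag guards them.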
@@ -759,9 +765,19 @@ if __name__ == "__main__":
                 del callbacks_cfg['ignore_keys_callback']

         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+        if not "plugins" in trainer_kwargs:
+            trainer_kwargs["plugins"] = list()
         if not lightning_config.get("find_unused_parameters", True):
             from pytorch_lightning.plugins import DDPPlugin
-            trainer_kwargs["plugins"] = DDPPlugin(find_unused_parameters=False)
+            trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        if MULTINODE_HACKS:
+            # disable resume from hpc ckpts
+            # NOTE below only works in later versions
+            # from pytorch_lightning.plugins.environments import SLURMEnvironment
+            # trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
+            # hence we monkey patch things
+            from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
+            setattr(CheckpointConnector, "hpc_resume_path", None)

         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
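The MULTINODE_HACKS block leans on a plain Python trick: assigning None onto the CheckpointConnector class shadows its hpc_resume_path property, so (per the comments above) Lightning no longer finds an HPC checkpoint to auto-resume from and a requeued SLURM job starts from the explicitly passed checkpoint instead. A tiny demonstration of the shadowing mechanism itself, not PyTorch Lightning code:

class Demo:
    @property
    def hpc_resume_path(self):
        return "/logs/hpc_ckpt_1.ckpt"

d = Demo()
print(d.hpc_resume_path)                # "/logs/hpc_ckpt_1.ckpt", served by the property
setattr(Demo, "hpc_resume_path", None)
print(d.hpc_resume_path)                # None: the class attribute now shadows the property for every instance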
@@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model


-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser()

     parser.add_argument(
@@ -258,3 +258,7 @@ if __name__ == "__main__":
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+    main()
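Moving the sampling script's body from module level into main(), with the call kept behind the __name__ guard, presumably lets the file be imported (for example to reuse load_model_from_config from another entry point) without argument parsing and sampling running as a side effect; invoked from the command line it behaves exactly as before.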