diff --git a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml
index bb0c934..15f7739 100644
--- a/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml
+++ b/configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml
@@ -2,7 +2,6 @@ model:
   base_learning_rate: 1.0e-04
   target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
-    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -20,7 +19,7 @@ model:
     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]
diff --git a/ldm/data/laion.py b/ldm/data/laion.py
index 73d928b..41d6fa3 100644
--- a/ldm/data/laion.py
+++ b/ldm/data/laion.py
@@ -151,12 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         if self.tar_base == "__improvedaesthetic__":
             print("## Warning, loading the same improved aesthetic dataset "
                   "for all splits and ignoring shards parameter.")
-            urls = []
-            for i in range(1, 65):
-                for j in range(512):
-                    for k in range(5):
-                        urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar')
-            tars = [f'pipe:aws s3 cp {url} -' for url in urls]
+            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
         else:
             tars = os.path.join(self.tar_base, dataset_config.shards)
 
@@ -314,7 +309,8 @@ if __name__ == "__main__":
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 
     #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
-    config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()
 
diff --git a/main.py b/main.py
index e8946a5..b274bb3 100644
--- a/main.py
+++ b/main.py
@@ -21,6 +21,9 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
 
 
+MULTINODE_HACKS = True
+
+
 def get_parser(**parser_kwargs):
     def str2bool(v):
         if isinstance(v, bool):
@@ -268,6 +271,9 @@ class SetupCallback(Callback):
                     os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
+            if MULTINODE_HACKS:
+                import time
+                time.sleep(5)
             OmegaConf.save(self.config,
                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
 
@@ -278,7 +284,7 @@
 
         else:
             # ModelCheckpoint callback created log directory --- remove it
-            if not self.resume and os.path.exists(self.logdir):
+            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                 dst, name = os.path.split(self.logdir)
                 dst = os.path.join(dst, "child_runs", name)
                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
@@ -759,9 +765,19 @@ if __name__ == "__main__":
                     del callbacks_cfg['ignore_keys_callback']
 
         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+        if not "plugins" in trainer_kwargs:
+            trainer_kwargs["plugins"] = list()
         if not lightning_config.get("find_unused_parameters", True):
             from pytorch_lightning.plugins import DDPPlugin
-            trainer_kwargs["plugins"] = DDPPlugin(find_unused_parameters=False)
+            trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        if MULTINODE_HACKS:
+            # disable resume from hpc ckpts
+            # NOTE below only works in later versions
+            # from pytorch_lightning.plugins.environments import SLURMEnvironment
+            # trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
+            # hence we monkey patch things
+            from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
+            setattr(CheckpointConnector, "hpc_resume_path", None)
 
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index cbf8525..6e98f83 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -258,3 +258,7 @@ if __name__ == "__main__":
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+    main()
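
A note on the shard-list change in ldm/data/laion.py: the nested-loop URL construction is replaced by a single brace-notation shard spec over the repacked aesthetics_tars shards. The sketch below is not part of the patch; it assumes webdataset's usual braceexpand-based shard handling and imports braceexpand directly only to illustrate what the one string expands to.

# Sketch: expand the new single-string shard spec into per-shard pipe commands.
from braceexpand import braceexpand

tars = ("pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/"
        "aesthetics_tars/{000000..060207}.tar -")
shards = list(braceexpand(tars))  # same expansion the shell would perform

print(len(shards))  # 60208 "pipe:aws s3 cp ... -" commands, one per tar shard
print(shards[0])    # pipe:aws s3 cp .../aesthetics_tars/000000.tar -
print(shards[-1])   # pipe:aws s3 cp .../aesthetics_tars/060207.tar -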
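
A note on the MULTINODE_HACKS monkey patch in main.py: as the patch comment says, the SLURMEnvironment(auto_requeue=False) switch only exists in later PyTorch Lightning releases, so the patch shadows CheckpointConnector.hpc_resume_path instead. The sketch below illustrates the effect under the assumption of PyTorch Lightning 1.4.x internals, where hpc_resume_path is a property that points the Trainer at any hpc_ckpt_*.ckpt written by a SLURM requeue.

# Sketch: shadow the hpc_resume_path property with a plain class attribute so
# no Trainer silently resumes from a stale hpc_ckpt_*.ckpt after a requeue.
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

print(type(vars(CheckpointConnector).get("hpc_resume_path")))  # <class 'property'> before the patch
setattr(CheckpointConnector, "hpc_resume_path", None)          # what main.py does under MULTINODE_HACKS
print(CheckpointConnector.hpc_resume_path)                     # None from now on, HPC auto-resume disabled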