multinode hacks

parent 7f8c423450
commit a7aad82e51

4 changed files with 27 additions and 12 deletions
@@ -2,7 +2,6 @@ model:
   base_learning_rate: 1.0e-04
   target: ldm.models.diffusion.ddpm.LatentDiffusion
   params:
-    ckpt_path: "/home/mchorse/stable-diffusion-ckpts/256pretrain-2022-06-09.ckpt"
     linear_start: 0.00085
     linear_end: 0.0120
     num_timesteps_cond: 1
@@ -20,7 +19,7 @@ model:
     scheduler_config: # 10000 warmup steps
       target: ldm.lr_scheduler.LambdaLinearScheduler
       params:
-        warm_up_steps: [ 10000 ]
+        warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
         f_start: [ 1.e-6 ]
         f_max: [ 1. ]
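Note: the warm-up change only matters at the very start of a run. A rough Python sketch of a linear warm-up factor (an approximation for illustration, not the actual ldm.lr_scheduler.LambdaLinearScheduler code) shows why warm_up_steps: [ 1 ] effectively disables warm-up when resuming from an already-warmed-up checkpoint:

    # Rough sketch of a linear LR warm-up multiplier; names and formula are
    # illustrative only, not copied from ldm.lr_scheduler.LambdaLinearScheduler.
    def warmup_factor(step: int, warm_up_steps: int = 1,
                      f_start: float = 1e-6, f_max: float = 1.0) -> float:
        if step < warm_up_steps:
            return f_start + (f_max - f_start) * step / warm_up_steps
        return f_max

    print(warmup_factor(0), warmup_factor(1), warmup_factor(10000))  # ~1e-06 1.0 1.0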
@@ -151,12 +151,7 @@ class WebDataModuleFromConfig(pl.LightningDataModule):
         if self.tar_base == "__improvedaesthetic__":
             print("## Warning, loading the same improved aesthetic dataset "
                     "for all splits and ignoring shards parameter.")
-            urls = []
-            for i in range(1, 65):
-                for j in range(512):
-                    for k in range(5):
-                        urls.append(f's3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics/{i:02d}/{j:03d}/{k:05d}.tar')
-            tars = [f'pipe:aws s3 cp {url} -' for url in urls]
+            tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
         else:
             tars = os.path.join(self.tar_base, dataset_config.shards)
 
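Note: the nested-loop URL list is replaced by a single brace-expansion shard pattern. A minimal sketch of how such a pipe: pattern is typically consumed (assuming the webdataset package, which this data module appears to build on; the actual pipeline wiring lives elsewhere in the file):

    # Minimal sketch: webdataset brace-expands {000000..060207} into one URL per
    # shard and shells out to `aws s3 cp ... -` to stream each tar file.
    import webdataset as wds

    tars = ("pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/"
            "aesthetics_tars/{000000..060207}.tar -")
    dataset = wds.WebDataset(tars, handler=wds.warn_and_continue).decode("rgb")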
@@ -314,7 +309,8 @@ if __name__ == "__main__":
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 
     #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
-    config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
+    config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
     datamod = WebDataModuleFromConfig(**config["data"]["params"])
     dataloader = datamod.train_dataloader()
 
main.py (20 changed lines)
@@ -21,6 +21,9 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config
 
 
+MULTINODE_HACKS = True
+
+
 def get_parser(**parser_kwargs):
     def str2bool(v):
         if isinstance(v, bool):
@@ -268,6 +271,9 @@ class SetupCallback(Callback):
                     os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
             print("Project config")
             print(OmegaConf.to_yaml(self.config))
+            if MULTINODE_HACKS:
+                import time
+                time.sleep(5)
             OmegaConf.save(self.config,
                            os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
 
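Note: the five-second pause is presumably there to smooth over shared-filesystem timing when several nodes go through SetupCallback at once. A more explicit alternative (an assumption, not what this commit does) would be to synchronize ranks with a distributed barrier once the process group exists:

    # Hypothetical alternative to the fixed sleep: block every rank until all of
    # them reach this point before touching the shared log directory.
    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        dist.barrier()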
@@ -278,7 +284,7 @@ class SetupCallback(Callback):
 
         else:
             # ModelCheckpoint callback created log directory --- remove it
-            if not self.resume and os.path.exists(self.logdir):
+            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                 dst, name = os.path.split(self.logdir)
                 dst = os.path.join(dst, "child_runs", name)
                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
@@ -759,9 +765,19 @@ if __name__ == "__main__":
             del callbacks_cfg['ignore_keys_callback']
 
         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+        if not "plugins" in trainer_kwargs:
+            trainer_kwargs["plugins"] = list()
         if not lightning_config.get("find_unused_parameters", True):
             from pytorch_lightning.plugins import DDPPlugin
-            trainer_kwargs["plugins"] = DDPPlugin(find_unused_parameters=False)
+            trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        if MULTINODE_HACKS:
+            # disable resume from hpc ckpts
+            # NOTE below only works in later versions
+            # from pytorch_lightning.plugins.environments import SLURMEnvironment
+            # trainer_kwargs["plugins"].append(SLURMEnvironment(auto_requeue=False))
+            # hence we monkey patch things
+            from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
+            setattr(CheckpointConnector, "hpc_resume_path", None)
 
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
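Note: the monkey patch relies on hpc_resume_path being resolved at class level (as a property) in the installed PyTorch Lightning version, which is an assumption about its internals; shadowing it with a plain None makes the checkpoint connector report that there is no HPC checkpoint to auto-resume from:

    # Standalone sketch of the patch above: assigning at class level shadows the
    # hpc_resume_path property, so every CheckpointConnector instance sees None.
    from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

    setattr(CheckpointConnector, "hpc_resume_path", None)
    assert CheckpointConnector.hpc_resume_path is None  # effective for all trainers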
@@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
     return model
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
@@ -258,3 +258,7 @@ if __name__ == "__main__":
     print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
           f"Sampling took {toc-tic}s, i.e. produced {opt.n_iter * opt.n_samples / (toc - tic):.2f} samples/sec."
           f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+    main()
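Note: with the body factored into main() behind an explicit guard, importing the sampling script no longer parses CLI arguments or starts sampling as a side effect; roughly (the module name below is a placeholder, the script's actual path is not shown in this diff):

    # Hypothetical import-side usage enabled by the main() refactor; "sample_script"
    # is a placeholder module name, not the file's real name in the repo.
    from sample_script import main

    main()  # argument parsing and sampling run only when invoked explicitly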