diff --git a/configs/stable-diffusion/inpainting/v1-edgeinpainting.yaml b/configs/stable-diffusion/inpainting/v1-edgeinpainting.yaml new file mode 100644 index 0000000..0b11032 --- /dev/null +++ b/configs/stable-diffusion/inpainting/v1-edgeinpainting.yaml @@ -0,0 +1,157 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + ckpt_path: "/fsx/stable-diffusion/stable-diffusion/checkpoints/v1pp/v1pp-flatlined-hr.ckpt" + + concat_keys: + - mask + - masked_image + - smoothing_strength + + c_concat_log_start: 1 + c_concat_log_end: 5 + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 10 # 4 data + 4 downscaled image + 1 mask + 1 strength + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: "__improvedaesthetic__" + batch_size: 2 + num_workers: 4 + multinode: True + min_size: 512 + max_pwatermark: 0.8 + train: + shards: '{00000..17279}.tar -' + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddEdge + params: + mode: "512train-large" + + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: '{17280..17535}.tar -' + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddEdge + params: + mode: "512train-large" + + +lightning: + find_unused_parameters: False + + modelcheckpoint: + params: + every_n_train_steps: 2000 + + callbacks: + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 3.0 + unconditional_guidance_label: [""] + ddim_steps: 100 # todo check these out for inpainting, + ddim_eta: 1.0 # todo check these out for inpainting, + + trainer: + benchmark: True + val_check_interval: 5000000 # really sorry + num_sanity_val_steps: 0 + accumulate_grad_batches: 2 diff --git a/environment.yaml b/environment.yaml index 3a8e42e..dd79e76 100644 --- a/environment.yaml +++ b/environment.yaml @@ -23,6 +23,7 @@ dependencies: - torch-fidelity==0.3.0 - transformers==4.3.1 - webdataset==0.2.5 + - kornia==0.6 - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e . diff --git a/ldm/data/laion.py b/ldm/data/laion.py index a32d699..5fb046d 100644 --- a/ldm/data/laion.py +++ b/ldm/data/laion.py @@ -1,4 +1,5 @@ import webdataset as wds +import kornia from PIL import Image import io import os @@ -258,6 +259,76 @@ class AddMask(PRNGMixin): return sample +class AddEdge(PRNGMixin): + def __init__(self, mode="512train", mask_edges=True): + super().__init__() + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] + self.n_down_choices = [0, 1, 2] + self.sigma_choices = [1, 2, 3, 4, 5] + self.mask_edges = mask_edges + + @torch.no_grad() + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample['jpg'] + + mask = self.make_mask(self.prng, x.shape[0], x.shape[1]) + mask[mask < 0.5] = 0 + mask[mask > 0.5] = 1 + mask = torch.from_numpy(mask[..., None]) + sample['mask'] = mask + + n_down_idx = self.prng.choice(len(self.n_down_choices)) + sigma_idx = self.prng.choice(len(self.sigma_choices)) + + n_choices = len(self.n_down_choices)*len(self.sigma_choices) + raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx), + (len(self.n_down_choices), len(self.sigma_choices))) + normalized_idx = raveled_idx/(n_choices-1) + + n_down = self.n_down_choices[n_down_idx] + sigma = self.sigma_choices[sigma_idx] + + kernel_size = 4*sigma+1 + kernel_size = (kernel_size, kernel_size) + sigma = (sigma, sigma) + canny = kornia.filters.Canny( + low_threshold=0.1, + high_threshold=0.2, + kernel_size=kernel_size, + sigma=sigma, + hysteresis=True, + ) + y = (x+1.0)/2.0 # in 01 + y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous() + + # down + for i_down in range(n_down): + size = min(y.shape[-2], y.shape[-1])//2 + y = kornia.geometry.transform.resize(y, size, antialias=True) + + # edge + _, y = canny(y) + + if n_down > 0: + size = x.shape[0], x.shape[1] + y = kornia.geometry.transform.resize(y, size, interpolation="nearest") + + y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous() + y = y*2.0-1.0 + + if self.mask_edges: + sample['masked_image'] = y * (mask < 0.5) + else: + sample['masked_image'] = y + + # concat normalized idx + sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx + + return sample + + def example00(): url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -" dataset = wds.WebDataset(url) diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index b078983..9c023af 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -1607,6 +1607,8 @@ class LatentInpaintDiffusion(LatentDiffusion): concat_keys=("mask", "masked_image"), masked_image_key="masked_image", keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, *args, **kwargs ): ckpt_path = kwargs.pop("ckpt_path", None) @@ -1617,6 +1619,8 @@ class LatentInpaintDiffusion(LatentDiffusion): self.finetune_keys = finetune_keys self.concat_keys = concat_keys self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' if exists(ckpt_path): self.init_from_ckpt(ckpt_path, ignore_keys) @@ -1707,6 +1711,9 @@ class LatentInpaintDiffusion(LatentDiffusion): if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end]) + if plot_diffusion_rows: # get diffusion row diffusion_row = list() diff --git a/scripts/slurm/v1_edgeinpainting/launcher.sh b/scripts/slurm/v1_edgeinpainting/launcher.sh new file mode 100755 index 0000000..b5d836f --- /dev/null +++ b/scripts/slurm/v1_edgeinpainting/launcher.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# mpi version for node rank +H=`hostname` +THEID=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"` +export NODE_RANK=${THEID} +echo THEID=$THEID + +echo "##########################################" +echo MASTER_ADDR=${MASTER_ADDR} +echo MASTER_PORT=${MASTER_PORT} +echo NODE_RANK=${NODE_RANK} +echo WORLD_SIZE=${WORLD_SIZE} +echo "##########################################" +# debug environment worked great so we stick with it +# no magic there, just a miniconda python=3.9, pytorch=1.12, cudatoolkit=11.3 +# env with pip dependencies from stable diffusion's requirements.txt +eval "$(/fsx/stable-diffusion/debug/miniconda3/bin/conda shell.bash hook)" +#conda activate stable +# torch 1.11 to avoid bug in ckpt restoring +conda activate torch111 +cd /fsx/stable-diffusion/stable-diffusion + +CONFIG="/fsx/stable-diffusion/stable-diffusion/configs/stable-diffusion/inpainting/v1-edgeinpainting.yaml" + +# resume and set new seed to reshuffle data +#EXTRA="--seed 543 --resume_from_checkpoint ..." + +# reduce lr a bit +#EXTRA="${EXTRA} model.params.scheduler_config.params.f_max=[0.75]" + +# custom logdir +#EXTRA="${EXTRA} --logdir rlogs" + +# debugging +#EXTRA="${EXTRA} -d True lightning.callbacks.image_logger.params.batch_frequency=50" + +# detect bad gpus early on +/bin/bash /fsx/stable-diffusion/stable-diffusion/scripts/test_gpu.sh + +python main.py --base $CONFIG --gpus 0,1,2,3,4,5,6,7 -t --num_nodes ${WORLD_SIZE} --scale_lr False diff --git a/scripts/slurm/v1_edgeinpainting/sbatch.sh b/scripts/slurm/v1_edgeinpainting/sbatch.sh new file mode 100755 index 0000000..e0a4ec8 --- /dev/null +++ b/scripts/slurm/v1_edgeinpainting/sbatch.sh @@ -0,0 +1,42 @@ +#!/bin/bash +#SBATCH --partition=compute-od-gpu +#SBATCH --job-name=stable-diffusion-v1-edgeinpainting +#SBATCH --nodes 24 +#SBATCH --ntasks-per-node 1 +#SBATCH --cpus-per-gpu=4 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --output=%x_%j.out +#SBATCH --comment "Key=Monitoring,Value=ON" +#SBATCH --no-requeue + +module load intelmpi +source /opt/intel/mpi/latest/env/vars.sh +export LD_LIBRARY_PATH=/opt/aws-ofi-nccl/lib:/opt/amazon/efa/lib64:/usr/local/cuda-11.0/efa/lib:/usr/local/cuda-11.0/lib:/usr/local/cuda-11.0/lib64:/usr/local/cuda-11.0:/opt/nccl/build/lib:/opt/aws-ofi-nccl-install/lib:/opt/aws-ofi-nccl/lib:$LD_LIBRARY_PATH +export NCCL_PROTO=simple +export PATH=/opt/amazon/efa/bin:$PATH +export LD_PRELOAD="/opt/nccl/build/lib/libnccl.so" +export FI_EFA_FORK_SAFE=1 +export FI_LOG_LEVEL=1 +export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn +export NCCL_DEBUG=info +export PYTHONFAULTHANDLER=1 +export CUDA_LAUNCH_BLOCKING=0 +export OMPI_MCA_mtl_base_verbose=1 +export FI_EFA_ENABLE_SHM_TRANSFER=0 +export FI_PROVIDER=efa +export FI_EFA_TX_MIN_CREDITS=64 +export NCCL_TREE_THRESHOLD=0 + +# sent to sub script +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=12802 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` +export WORLD_SIZE=$COUNT_NODE + +echo go $COUNT_NODE +echo $HOSTNAMES +echo $WORLD_SIZE + +mpirun -n $COUNT_NODE -perhost 1 /fsx/stable-diffusion/stable-diffusion/scripts/slurm/v1_edgeinpainting/launcher.sh