From 171cf29fb54afe048b03ec73da8abb9d102d0614 Mon Sep 17 00:00:00 2001 From: ablattmann Date: Wed, 22 Dec 2021 15:57:23 +0100 Subject: [PATCH] add configs for training unconditional/class-conditional ldms --- README.md | 100 +++++++++++++++--- .../latent-diffusion/celebahq-ldm-vq-4.yaml | 86 +++++++++++++++ configs/latent-diffusion/cin-ldm-vq-f8.yaml | 98 +++++++++++++++++ configs/latent-diffusion/ffhq-ldm-vq-4.yaml | 85 +++++++++++++++ .../lsun_bedrooms-ldm-vq-4.yaml | 85 +++++++++++++++ ...r-ldm.yaml => lsun_churches-ldm-kl-8.yaml} | 10 +- ldm/models/diffusion/ddim.py | 9 +- ldm/models/diffusion/ddpm.py | 59 +++++++---- ldm/modules/diffusionmodules/util.py | 6 ++ main.py | 5 +- models/ldm/semantic_synthesis256/config.yaml | 59 +++++++++++ scripts/download_first_stages.sh | 4 +- scripts/download_models.sh | 9 +- 13 files changed, 562 insertions(+), 53 deletions(-) create mode 100644 configs/latent-diffusion/celebahq-ldm-vq-4.yaml create mode 100644 configs/latent-diffusion/cin-ldm-vq-f8.yaml create mode 100644 configs/latent-diffusion/ffhq-ldm-vq-4.yaml create mode 100644 configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml rename configs/latent-diffusion/{lsun_churches_f8-autoencoder-ldm.yaml => lsun_churches-ldm-kl-8.yaml} (82%) create mode 100644 models/ldm/semantic_synthesis256/config.yaml diff --git a/README.md b/README.md index 50d768a..b162fa2 100644 --- a/README.md +++ b/README.md @@ -55,18 +55,7 @@ bash scripts/download_first_stages.sh ``` The first stage models can then be found in `models/first_stage_models/` -### Training autoencoder models -Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`. -Training can be started by running -``` -CUDA_VISIBLE_DEVICES= python main.py --base configs/autoencoder/ -t --gpus 0, -``` -where `config_spec` is one of {`autoencoder_kl_8x8x64.yaml`(f=32, d=64), `autoencoder_kl_16x16x16.yaml`(f=16, d=16), -`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}. - -For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers) -repository. ## Pretrained LDMs @@ -78,9 +67,10 @@ repository. | LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | | | ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) | | Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION | -| OpenImages | Super-resolution | N/A | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation | +| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation | | OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | | -| Landscapes (finetuned 512) | Semantic Image Synthesis | LDM-VQ-4 (100 DDIM steps, eta=1) | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | | +| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | | +| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 | ### Get the models @@ -116,10 +106,90 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp `indir` should contain images `*.png` and masks `_mask.png` like the examples provided in `data/inpainting_examples`. + +# Train your own LDMs + +## Data preparation + +### Faces +For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq) +repository. + +### LSUN + +The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun). +We performed a custom split into training and validation images, and provide the corresponding filenames +at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip). +After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should +also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively. + +### ImageNet +The code will try to download (through [Academic +Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it +is used. However, since ImageNet is quite large, this requires a lot of disk +space and time. If you already have ImageNet on your disk, you can speed things +up by putting the data into +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to +`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one +of `train`/`validation`. It should have the following structure: + +``` +${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/ +├── n01440764 +│ ├── n01440764_10026.JPEG +│ ├── n01440764_10027.JPEG +│ ├── ... +├── n01443537 +│ ├── n01443537_10007.JPEG +│ ├── n01443537_10014.JPEG +│ ├── ... +├── ... +``` + +If you haven't extracted the data, you can also place +`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` / +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be +extracted into above structure without downloading it again. Note that this +will only happen if neither a folder +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them +if you want to force running the dataset preparation again. + + +## Model Training + +Logs and checkpoints for trained models are saved to `logs/_`. + +### Training autoencoder models + +Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`. +Training can be started by running +``` +CUDA_VISIBLE_DEVICES= python main.py --base configs/autoencoder/.yaml -t --gpus 0, +``` +where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16), +`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}. + +For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers) +repository. + +### Training LDMs + +In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets. +Training can be started by running + +```shell script +CUDA_VISIBLE_DEVICES= python main.py --base configs/latent-diffusion/.yaml -t --gpus 0, +``` + +where ```` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3), +`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3), +`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}. + ## Coming Soon... -* Code for training LDMs and the corresponding compression models. -* Inference scripts for conditional LDMs for various conditioning modalities. +* More inference scripts for conditional LDMs. * In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing * We will also release some further pretrained models. diff --git a/configs/latent-diffusion/celebahq-ldm-vq-4.yaml b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml new file mode 100644 index 0000000..89b3df4 --- /dev/null +++ b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml @@ -0,0 +1,86 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: models/first_stage_models/vq-f4/model.ckpt + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.CelebAHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.CelebAHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/cin-ldm-vq-f8.yaml b/configs/latent-diffusion/cin-ldm-vq-f8.yaml new file mode 100644 index 0000000..b8cd9e2 --- /dev/null +++ b/configs/latent-diffusion/cin-ldm-vq-f8.yaml @@ -0,0 +1,98 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: class_label + image_size: 32 + channels: 4 + cond_stage_trainable: true + conditioning_key: crossattn + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: + #note: this isn\t actually the resolution but + # the downsampling factor, i.e. this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 32 for f8 + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + num_head_channels: 32 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 512 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 4 + n_embed: 16384 + ckpt_path: configs/first_stage_models/vq-f8/model.yaml + ddconfig: + double_z: false + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: + - 32 + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 512 + key: class_label +data: + target: main.DataModuleFromConfig + params: + batch_size: 64 + num_workers: 12 + wrap: false + train: + target: ldm.data.imagenet.ImageNetTrain + params: + config: + size: 256 + validation: + target: ldm.data.imagenet.ImageNetValidation + params: + config: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/ffhq-ldm-vq-4.yaml b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml new file mode 100644 index 0000000..1899e30 --- /dev/null +++ b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 42 + num_workers: 5 + wrap: false + train: + target: taming.data.faceshq.FFHQTrain + params: + size: 256 + validation: + target: taming.data.faceshq.FFHQValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml new file mode 100644 index 0000000..c4ca66c --- /dev/null +++ b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml @@ -0,0 +1,85 @@ +model: + base_learning_rate: 2.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0195 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + image_size: 64 + channels: 3 + monitor: val/loss_simple_ema + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 3 + out_channels: 3 + model_channels: 224 + attention_resolutions: + # note: this isn\t actually the resolution but + # the downsampling factor, i.e. this corresnponds to + # attention on spatial resolution 8,16,32, as the + # spatial reolution of the latents is 64 for f4 + - 8 + - 4 + - 2 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 3 + - 4 + num_head_channels: 32 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + ckpt_path: configs/first_stage_models/vq-f4/model.yaml + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: __is_unconditional__ +data: + target: main.DataModuleFromConfig + params: + batch_size: 48 + num_workers: 5 + wrap: false + train: + target: ldm.data.lsun.LSUNBedroomsTrain + params: + size: 256 + validation: + target: ldm.data.lsun.LSUNBedroomsValidation + params: + size: 256 + + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 5000 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml similarity index 82% rename from configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml rename to configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml index a2e99c6..18dc8c2 100644 --- a/configs/latent-diffusion/lsun_churches_f8-autoencoder-ldm.yaml +++ b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml @@ -45,7 +45,7 @@ model: params: embed_dim: 4 monitor: "val/rec_loss" - ckpt_path: "/export/compvis-nfs/user/ablattma/logs/braket/2021-11-26T11-25-56_lsun_churches-convae-f8-ft_from_oi/checkpoints/step=000180071-fidfrechet_inception_distance=2.335.ckpt" + ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" ddconfig: double_z: True z_channels: 4 @@ -65,7 +65,7 @@ model: data: target: main.DataModuleFromConfig params: - batch_size: 24 # TODO: was 96 in our experiments + batch_size: 96 num_workers: 5 wrap: False train: @@ -82,14 +82,10 @@ lightning: image_logger: target: main.ImageLogger params: - batch_frequency: 1000 # TODO 5000 + batch_frequency: 5000 max_images: 8 increase_log_steps: False - metrics_over_trainsteps_checkpoint: - target: pytorch_lightning.callbacks.ModelCheckpoint - params: - every_n_train_steps: 20000 trainer: benchmark: True \ No newline at end of file diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 546e81a..01c0e56 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -5,8 +5,7 @@ import numpy as np from tqdm import tqdm from functools import partial -from ldm.models.diffusion.ddpm import noise_like -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like class DDIMSampler(object): @@ -27,8 +26,7 @@ class DDIMSampler(object): num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32, device=self.model.device) + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) @@ -73,7 +71,8 @@ class DDIMSampler(object): corrector_kwargs=None, verbose=True, x_T=None, - log_every_t=100 + log_every_t=100, + **kwargs ): if conditioning is not None: if isinstance(conditioning, dict): diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 642204a..9835129 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -16,14 +16,14 @@ from contextlib import contextmanager from functools import partial from tqdm import tqdm from torchvision.utils import make_grid -from PIL import Image from pytorch_lightning.utilities.distributed import rank_zero_only from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler __conditioning_keys__ = {'concat': 'c_concat', @@ -37,12 +37,6 @@ def disabled_train(self, mode=True): return self -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() - - def uniform_on_device(r1, r2, shape, device): return (r1 - r2) * torch.rand(*shape, device=device) + r2 @@ -119,6 +113,7 @@ class DDPM(pl.LightningModule): if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if exists(given_betas): @@ -1188,7 +1183,6 @@ class LatentDiffusion(DDPM): if start_T is not None: timesteps = min(timesteps, start_T) - print(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( range(0, timesteps)) @@ -1222,7 +1216,7 @@ class LatentDiffusion(DDPM): @torch.no_grad() def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None): + mask=None, x0=None, shape=None,**kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: @@ -1238,10 +1232,28 @@ class LatentDiffusion(DDPM): mask=mask, x0=x0) @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, sample_ddim=False, return_keys=None, + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, **kwargs): - # TODO: maybe add option for ddim sampling via DDIMSampler class + + use_ddim = ddim_steps is not None + log = dict() z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, @@ -1288,7 +1300,9 @@ class LatentDiffusion(DDPM): if sample: # get denoise row with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples if plot_denoise_rows: @@ -1299,8 +1313,11 @@ class LatentDiffusion(DDPM): self.first_stage_model, IdentityFirstStage): # also display when quantizing x0 while sampling with self.ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - quantize_denoised=True) + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_x0_quantized"] = x_samples @@ -1312,19 +1329,17 @@ class LatentDiffusion(DDPM): mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask = mask[:, None, ...] with self.ema_scope("Plotting Inpaint"): - samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - quantize_denoised=False, x0=z[:N], mask=mask) + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row_inpainting"] = denoise_grid # outpaint with self.ema_scope("Plotting Outpaint"): - samples = self.sample(cond=c, batch_size=N, return_intermediates=False, - quantize_denoised=False, x0=z[:N], mask=1. - mask) + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index b626b9a..a952e6c 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -259,3 +259,9 @@ class HybridConditioner(nn.Module): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/main.py b/main.py index d3498b2..e8e18c1 100644 --- a/main.py +++ b/main.py @@ -676,7 +676,10 @@ if __name__ == "__main__": ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) else: ngpu = 1 - accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches or 1 + if 'accumulate_grad_batches' in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches + else: + accumulate_grad_batches = 1 print(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches if opt.scale_lr: diff --git a/models/ldm/semantic_synthesis256/config.yaml b/models/ldm/semantic_synthesis256/config.yaml new file mode 100644 index 0000000..1a721cf --- /dev/null +++ b/models/ldm/semantic_synthesis256/config.yaml @@ -0,0 +1,59 @@ +model: + base_learning_rate: 1.0e-06 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.0015 + linear_end: 0.0205 + log_every_t: 100 + timesteps: 1000 + loss_type: l1 + first_stage_key: image + cond_stage_key: segmentation + image_size: 64 + channels: 3 + concat_mode: true + cond_stage_trainable: true + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 64 + in_channels: 6 + out_channels: 3 + model_channels: 128 + attention_resolutions: + - 32 + - 16 + - 8 + num_res_blocks: 2 + channel_mult: + - 1 + - 4 + - 8 + num_heads: 8 + first_stage_config: + target: ldm.models.autoencoder.VQModelInterface + params: + embed_dim: 3 + n_embed: 8192 + ddconfig: + double_z: false + z_channels: 3 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + cond_stage_config: + target: ldm.modules.encoders.modules.SpatialRescaler + params: + n_stages: 2 + in_channels: 182 + out_channels: 3 diff --git a/scripts/download_first_stages.sh b/scripts/download_first_stages.sh index 6e8b775..a8d79e9 100644 --- a/scripts/download_first_stages.sh +++ b/scripts/download_first_stages.sh @@ -4,10 +4,10 @@ wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/la wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip -wget -O models/first_stage_models/vq-f4-noattn/model.zip https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 +wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip -wget -O models/first_stage_models/vq-f16/model.zip https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 +wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip diff --git a/scripts/download_models.sh b/scripts/download_models.sh index a6e74de..84297d7 100644 --- a/scripts/download_models.sh +++ b/scripts/download_models.sh @@ -6,9 +6,10 @@ wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/la wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip +wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip -wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1 +wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip @@ -33,10 +34,16 @@ unzip -o model.zip cd ../semantic_synthesis512 unzip -o model.zip +cd ../semantic_synthesis256 +unzip -o model.zip + cd ../bsr_sr unzip -o model.zip cd ../layout2img-openimages256 unzip -o model.zip +cd ../inpainting_big +unzip -o model.zip + cd ../..