add new models
Commit a80ed7987e (parent 2b46bcb98c)
13 changed files with 2056 additions and 80 deletions
README.md (169 lines changed)

@@ -28,71 +28,46 @@ conda env create -f environment.yaml
conda activate ldm
```

# Pretrained Models
A general list of all available checkpoints is available via our [model zoo](#model-zoo).
If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).

## Text-to-Image
![text2img-figure](assets/txt2img-preview.png)

Download the pre-trained weights (5.7GB)
```
mkdir -p models/ldm/text2img-large/
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
```
and sample with
```
python scripts/txt2img.py --prompt "a virus monster is playing guitar, oil on canvas" --ddim_eta 0.0 --n_samples 4 --n_iter 4 --scale 5.0 --ddim_steps 50
```
This will save each sample individually as well as a grid of size `n_iter` x `n_samples` at the specified output location (default: `outputs/txt2img-samples`).
Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` arguments.
As a rule of thumb, higher values of `scale` produce better samples at the cost of reduced output diversity.
Furthermore, increasing `ddim_steps` generally also gives higher-quality samples, but returns diminish for values > 250.
Fast sampling (i.e. low values of `ddim_steps`) while retaining good quality can be achieved by using `--ddim_eta 0.0`.

#### Beyond 256²
For certain inputs, simply running the model in a convolutional fashion on larger features than it was trained on
can sometimes lead to interesting results. To try it out, tune the `H` and `W` arguments (which will be integer-divided
by 8 in order to calculate the corresponding latent size), e.g. run
```
python scripts/txt2img.py --prompt "a sunset behind a mountain range, vector image" --ddim_eta 1.0 --n_samples 1 --n_iter 1 --H 384 --W 1024 --scale 5.0
```
to create a sample of size 384x1024. Note, however, that controllability is reduced compared to the 256x256 setting.

The example below was generated using the above command.
![text2img-figure-conv](assets/txt2img-convsample.png)

## Inpainting
![inpainting](assets/inpainting.png)

Download the pre-trained weights
@@ -107,6 +82,22 @@ python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inp
`indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
the examples provided in `data/inpainting_examples`.

## Class-Conditional ImageNet

Available via a [notebook](scripts/latent_imagenet_diffusion.ipynb) [![][colab]][colab-cin].
![class-conditional](assets/birdhouse.png)

[colab]: <https://colab.research.google.com/assets/colab-badge.svg>
[colab-cin]: <https://colab.research.google.com/github/CompVis/latent-diffusion/blob/main/scripts/latent-imagenet-diffusion.ipynb>
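For readers who want the gist of the notebook as plain code, the loop below is a minimal sketch of class-conditional sampling with the `cin256-v2.yaml` config added in this commit. The checkpoint path and the use of the extra class index 1000 (the config sets `n_classes: 1001`) as the null label for classifier-free guidance are assumptions of this sketch, not something this README specifies.

```
# Minimal sketch; assumes the ImageNet LDM checkpoint was extracted to models/ldm/cin256/
# and that class index 1000 acts as the null label for unconditional guidance.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load("models/ldm/cin256/model.ckpt", map_location="cpu")["state_dict"], strict=False)
model = model.cuda().eval()
sampler = DDIMSampler(model)

n, class_idx, scale = 4, 448, 5.0  # class_idx is an illustrative ImageNet class index
with torch.no_grad(), model.ema_scope():
    c = model.get_learned_conditioning({model.cond_stage_key: torch.tensor(n * [class_idx]).cuda()})
    uc = model.get_learned_conditioning({model.cond_stage_key: torch.tensor(n * [1000]).cuda()})
    samples, _ = sampler.sample(S=200, conditioning=c, batch_size=n, shape=[3, 64, 64],
                                verbose=False, unconditional_guidance_scale=scale,
                                unconditional_conditioning=uc, eta=1.0)
    imgs = torch.clamp((model.decode_first_stage(samples) + 1.0) / 2.0, 0.0, 1.0)  # (n, 3, 256, 256)
```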

## Unconditional Models

We also provide a script for sampling from unconditional LDMs (e.g. LSUN, FFHQ, ...). Start it via

```shell script
CUDA_VISIBLE_DEVICES=<GPU_ID> python scripts/sample_diffusion.py -r models/ldm/<model_spec>/model.ckpt -l <logdir> -n <\#samples> --batch_size <batch_size> -c <\#ddim steps> -e <\#eta>
```

# Train your own LDMs
@@ -188,16 +179,72 @@ where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder,
`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.

# Model Zoo

## Pretrained Autoencoding Models
![rec2](assets/reconstruction2.png)

All models were trained until convergence (no further substantial improvement in rFID).

| Model | rFID vs val | train steps | PSNR | PSIM | Link | Comments |
|-------|-------------|-------------|------|------|------|----------|
| f=4, VQ (Z=8192, d=3) | 0.58 | 533066 | 27.43 +/- 4.26 | 0.53 +/- 0.21 | https://ommer-lab.com/files/latent-diffusion/vq-f4.zip | |
| f=4, VQ (Z=8192, d=3) | 1.06 | 658131 | 25.21 +/- 4.17 | 0.72 +/- 0.26 | https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 | no attention |
| f=8, VQ (Z=16384, d=4) | 1.14 | 971043 | 23.07 +/- 3.99 | 1.17 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | |
| f=8, VQ (Z=256, d=4) | 1.49 | 1608649 | 22.35 +/- 3.81 | 1.26 +/- 0.37 | https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | |
| f=16, VQ (Z=16384, d=8) | 5.15 | 1101166 | 20.83 +/- 3.61 | 1.73 +/- 0.43 | https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 | |
| | | | | | | |
| f=4, KL | 0.27 | 176991 | 27.53 +/- 4.54 | 0.55 +/- 0.24 | https://ommer-lab.com/files/latent-diffusion/kl-f4.zip | |
| f=8, KL | 0.90 | 246803 | 24.19 +/- 4.19 | 1.02 +/- 0.35 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip | |
| f=16, KL (d=16) | 0.87 | 442998 | 24.08 +/- 4.22 | 1.07 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/kl-f16.zip | |
| f=32, KL (d=64) | 2.04 | 406763 | 22.27 +/- 3.93 | 1.41 +/- 0.40 | https://ommer-lab.com/files/latent-diffusion/kl-f32.zip | |

### Get the models

Running the following script downloads and extracts all available pretrained autoencoding models.
```shell script
bash scripts/download_first_stages.sh
```

The first stage models can then be found in `models/first_stage_models/<model_spec>`.

## Pretrained LDMs

| Dataset | Task | Model | FID | IS | Prec | Recall | Link | Comments |
|---------|------|-------|-----|----|------|--------|------|----------|
| CelebA-HQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 5.11 (5.11) | 3.29 | 0.72 | 0.49 | https://ommer-lab.com/files/latent-diffusion/celeba.zip | |
| FFHQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1) | 4.98 (4.98) | 4.50 (4.50) | 0.73 | 0.50 | https://ommer-lab.com/files/latent-diffusion/ffhq.zip | |
| LSUN-Churches | Unconditional Image Synthesis | LDM-KL-8 (400 DDIM steps, eta=0) | 4.02 (4.02) | 2.72 | 0.64 | 0.52 | https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip | |
| LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1) | 2.95 (3.0) | 2.22 (2.23) | 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
| ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77 (7.76)* / 15.82** | 201.56 (209.52)* / 78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10; **: w/o guiding; scores in brackets calculated with the script provided by [ADM](https://github.com/openai/guided-diffusion) |
| Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
| OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |

### Get the models

The LDMs listed above can jointly be downloaded and extracted via

```shell script
bash scripts/download_models.sh
```

The models can then be found in `models/ldm/<model_spec>`.

## Coming Soon...

* More inference scripts for conditional LDMs.
* In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing

## Comments

- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
Thanks for open-sourcing!
BIN  assets/birdhouse.png (new file, 757 KiB)
BIN  assets/txt2img-convsample.png (new file, 302 KiB)
BIN  assets/txt2img-preview.png (new file, 2.2 MiB)
configs/latent-diffusion/cin256-v2.yaml (new file, 68 lines)

@@ -0,0 +1,68 @@
model:
  base_learning_rate: 0.0001
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: class_label
    image_size: 64
    channels: 3
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 3
        out_channels: 3
        model_channels: 192
        attention_resolutions:
        - 8
        - 4
        - 2
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 3
        - 5
        num_heads: 1
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 512

    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.ClassEmbedder
      params:
        n_classes: 1001
        embed_dim: 512
        key: class_label
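The `cond_stage_config` above is the only ImageNet-specific piece: class labels are turned into a single 512-dimensional token that the U-Net attends to via cross-attention (`context_dim: 512`). The snippet below is an illustrative re-implementation of that embedder, not the repo's `ldm.modules.encoders.modules.ClassEmbedder` itself, to show what the conditioning amounts to.

```
# Illustrative sketch of the class-label conditioning implied by the config above
# (hypothetical ClassEmbedderSketch; the real class lives in ldm.modules.encoders.modules).
import torch
import torch.nn as nn

class ClassEmbedderSketch(nn.Module):
    def __init__(self, embed_dim=512, n_classes=1001, key="class_label"):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        key = key or self.key
        c = batch[key][:, None]      # (B,) class indices -> (B, 1)
        return self.embedding(c)     # (B, 1, 512) cross-attention context

emb = ClassEmbedderSketch()
ctx = emb({"class_label": torch.tensor([1, 7, 1000])})
print(ctx.shape)  # torch.Size([3, 1, 512])
```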
configs/latent-diffusion/txt2img-1p4B-eval.yaml (new file, 71 lines)

@@ -0,0 +1,71 @@
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 32
    channels: 4
    cond_stage_trainable: true
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 1280
        use_checkpoint: true
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.BERTEmbedder
      params:
        n_embed: 1280
        n_layer: 32
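A quick sanity check on how this config's numbers line up with `scripts/txt2img.py`, added further down in this commit: the f=8 KL autoencoder (three downsampling stages from `ch_mult: [1, 2, 4, 4]`) maps a 256x256 RGB image to a 4x32x32 latent, which is exactly the `shape = [4, H//8, W//8]` the script passes to the sampler, while the BERT-style text encoder supplies the 1280-dimensional cross-attention context (`context_dim: 1280`).

```
# Latent geometry implied by the config above (plain arithmetic, no model required).
H, W, latent_channels, f = 256, 256, 4, 8
latent_shape = [latent_channels, H // f, W // f]
assert latent_shape == [4, 32, 32]   # matches shape = [4, opt.H//8, opt.W//8] in scripts/txt2img.py
```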
data/imagenet_clsidx_to_label.txt (new executable file, 1000 lines; diff suppressed because it is too large)
ldm/models/diffusion/ddim.py

@@ -72,6 +72,9 @@ class DDIMSampler(object):
               verbose=True,
               x_T=None,
               log_every_t=100,
+              unconditional_guidance_scale=1.,
+              unconditional_conditioning=None,
+              # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
@@ -100,7 +103,9 @@ class DDIMSampler(object):
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
-                                                   log_every_t=log_every_t
+                                                   log_every_t=log_every_t,
+                                                   unconditional_guidance_scale=unconditional_guidance_scale,
+                                                   unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates
@@ -109,7 +114,8 @@ class DDIMSampler(object):
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
-                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                     unconditional_guidance_scale=1., unconditional_conditioning=None,):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
@@ -142,7 +148,9 @@ class DDIMSampler(object):
            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
-                                     corrector_kwargs=corrector_kwargs)
+                                     corrector_kwargs=corrector_kwargs,
+                                     unconditional_guidance_scale=unconditional_guidance_scale,
+                                     unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)
@@ -155,9 +163,16 @@ class DDIMSampler(object):

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+                     temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                     unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

+       if unconditional_guidance_scale > 1.:
+           assert unconditional_conditioning is not None
+           e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+           e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
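The new `unconditional_guidance_scale` / `unconditional_conditioning` arguments thread classifier-free guidance through `sample`, `ddim_sampling` and `p_sample_ddim`. Written as a standalone helper, the noise-prediction mix is just the few lines below (a sketch; `denoise_fn` stands in for `self.model.apply_model`).

```
# Sketch of the guidance mix performed in p_sample_ddim; `denoise_fn(x, t, cond)` is a
# stand-in for model.apply_model, not repo API.
def guided_eps(denoise_fn, x, t, cond, uncond, scale):
    e_t = denoise_fn(x, t, cond)               # conditional noise prediction
    if scale > 1.0:
        e_t_uncond = denoise_fn(x, t, uncond)  # prediction for the "empty" conditioning
        e_t = e_t_uncond + scale * (e_t - e_t_uncond)
    return e_t
```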
ldm/modules/diffusionmodules/openaimodel.py

@@ -455,7 +455,7 @@ class UNetModel(nn.Module):
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
-       num_heads=1,
+       num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
@@ -464,21 +464,28 @@ class UNetModel(nn.Module):
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
-       n_embed=None    # custom support for prediction of discrete ids into codebook of first stage vq model
+       n_embed=None,   # custom support for prediction of discrete ids into codebook of first stage vq model
+       legacy=True,
    ):
        super().__init__()

        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+           from omegaconf.listconfig import ListConfig
+           if type(context_dim) == ListConfig:
+               context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

+       if num_heads == -1:
+           assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+       if num_head_channels == -1:
+           assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
@@ -532,13 +539,20 @@ class UNetModel(nn.Module):
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
-                   dim_head = ch // num_heads
+                   if num_head_channels == -1:
+                       dim_head = ch // num_heads
+                   else:
+                       num_heads = ch // num_head_channels
+                       dim_head = num_head_channels
+                   if legacy:
+                       num_heads = 1
+                       dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
-                           num_head_channels=num_head_channels,
+                           num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -572,7 +586,14 @@ class UNetModel(nn.Module):
                ds *= 2
                self._feature_size += ch

-       dim_head = ch // num_heads
+       if num_head_channels == -1:
+           dim_head = ch // num_heads
+       else:
+           num_heads = ch // num_head_channels
+           dim_head = num_head_channels
+       if legacy:
+           num_heads = 1
+           dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
@@ -586,7 +607,7 @@ class UNetModel(nn.Module):
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
-               num_head_channels=num_head_channels,
+               num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -619,13 +640,20 @@ class UNetModel(nn.Module):
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
-                   dim_head = ch // num_heads
+                   if num_head_channels == -1:
+                       dim_head = ch // num_heads
+                   else:
+                       num_heads = ch // num_head_channels
+                       dim_head = num_head_channels
+                   if legacy:
+                       num_heads = 1
+                       dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
-                           num_head_channels=num_head_channels,
+                           num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
@@ -691,7 +719,6 @@ class UNetModel(nn.Module):
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
-       assert timesteps is not None, 'need to implement no-timestep usage'
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)
@@ -710,14 +737,12 @@ class UNetModel(nn.Module):
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
-           #return self.out(h), self.id_predictor(h)
            return self.id_predictor(h)
        else:
            return self.out(h)


class EncoderUNetModel(nn.Module):
-   # TODO: do we use it ?
    """
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.
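The same head-resolution block now appears in the downsampling path, the middle block and the upsampling path of `UNetModel`. Factored out purely as a reading aid (a hypothetical helper, not something the commit adds), the rule is:

```
# How num_heads / dim_head are resolved at each attention resolution (sketch of the logic above).
def resolve_heads(ch, num_heads, num_head_channels, use_spatial_transformer, legacy=True):
    if num_head_channels == -1:
        dim_head = ch // num_heads
    else:
        num_heads = ch // num_head_channels
        dim_head = num_head_channels
    if legacy:
        num_heads = 1
        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
    return num_heads, dim_head

print(resolve_heads(640, -1, 64, use_spatial_transformer=True, legacy=False))  # (10, 64)
```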
ldm/modules/losses/vqperceptual.py (new file, 167 lines)

@@ -0,0 +1,167 @@
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss


def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
    loss_real = (weights * loss_real).sum() / weights.sum()
    loss_fake = (weights * loss_fake).sum() / weights.sum()
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def adopt_weight(weight, global_step, threshold=0, value=0.):
    if global_step < threshold:
        weight = value
    return weight


def measure_perplexity(predicted_indices, n_embed):
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    return torch.abs(x-y)


def l2(x, y):
    return torch.pow((x-y), 2)


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        if not exists(codebook_loss):
            codebook_loss = torch.tensor([0.]).to(inputs.device)
        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            rec_loss = rec_loss + self.perceptual_weight * p_loss
        else:
            p_loss = torch.tensor([0.0])

        nll_loss = rec_loss
        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        nll_loss = torch.mean(nll_loss)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
            except RuntimeError:
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   "{}/disc_factor".format(split): torch.tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
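For context on how this loss is meant to be driven, the sketch below mirrors the usual two-optimizer training step of a VQ autoencoder: pass 0 updates the autoencoder with the reconstruction/perceptual/codebook terms plus the adaptively weighted GAN loss, pass 1 updates the discriminator. The surrounding pieces (`self(...)`, `get_last_layer`, `global_step`) are stand-ins for the autoencoder's training module, not part of this file.

```
# Hypothetical training_step sketch showing how VQLPIPSWithDiscriminator.forward is called.
def training_step(self, batch, batch_idx, optimizer_idx):
    x = batch["image"]
    xrec, qloss, ind = self(x, return_pred_indices=True)  # reconstruction, codebook loss, code indices

    if optimizer_idx == 0:   # autoencoder / generator pass
        aeloss, log = self.loss(qloss, x, xrec, 0, self.global_step,
                                last_layer=self.get_last_layer(), split="train",
                                predicted_indices=ind)
        return aeloss

    if optimizer_idx == 1:   # discriminator pass
        discloss, log = self.loss(qloss, x, xrec, 1, self.global_step,
                                  last_layer=self.get_last_layer(), split="train")
        return discloss
```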
ldm/modules/x_transformer.py

@@ -407,7 +407,7 @@ class AttentionLayers(nn.Module):
            self.rotary_pos_emb = always(None)

        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
-       self.rel_pos = always(None)
+       self.rel_pos = None

        self.pre_norm = pre_norm

scripts/latent_imagenet_diffusion.ipynb (new file, 429 lines; diff suppressed because one or more lines are too long)
scripts/txt2img.py (new file, 154 lines)

@@ -0,0 +1,154 @@
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=200,
        help="number of ddim sampling steps",
    )

    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=1,
        help="sample this often",
    )

    parser.add_argument(
        "--H",
        type=int,
        default=256,
        help="image height, in pixel space",
    )

    parser.add_argument(
        "--W",
        type=int,
        default=256,
        help="image width, in pixel space",
    )

    parser.add_argument(
        "--n_samples",
        type=int,
        default=4,
        help="how many samples to produce for the given prompt",
    )

    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    opt = parser.parse_args()


    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt")  # TODO: check path

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    prompt = opt.prompt


    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

    all_samples = list()
    with torch.no_grad():
        with model.ema_scope():
            uc = None
            if opt.scale > 0:
                uc = model.get_learned_conditioning(opt.n_samples * [""])
            for n in trange(opt.n_iter, desc="Sampling"):
                c = model.get_learned_conditioning(opt.n_samples * [prompt])
                shape = [4, opt.H//8, opt.W//8]
                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                 conditioning=c,
                                                 batch_size=opt.n_samples,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=opt.scale,
                                                 unconditional_conditioning=uc,
                                                 eta=opt.ddim_eta)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)

                for x_sample in x_samples_ddim:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
                    base_count += 1
                all_samples.append(x_samples_ddim)


    # additionally, save as grid
    grid = torch.stack(all_samples, 0)
    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
    grid = make_grid(grid, nrow=opt.n_samples)

    # to image
    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}.png'))

    print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")