Img condition (#1)

* update reqs
* add image variations
* update readme

parent 693e713c3e
commit 7e3956ef74

10 changed files with 455 additions and 281 deletions
.gitignore (vendored, new file, +7)

@@ -0,0 +1,7 @@

```
logs/
dump/
examples/
outputs/
flagged/
*.egg-info
__pycache__
```
README.md (281 lines changed)

@@ -1,278 +1,13 @@
-# Latent Diffusion Models
-[arXiv](https://arxiv.org/abs/2112.10752) | [BibTeX](#bibtex)
-
-<p align="center">
-<img src=assets/results.gif />
-</p>
-
-[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
-[Robin Rombach](https://github.com/rromb)\*,
-[Andreas Blattmann](https://github.com/ablattmann)\*,
-[Dominik Lorenz](https://github.com/qp-qp)\,
-[Patrick Esser](https://github.com/pesser),
-[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
-\* equal contribution
-
-<p align="center">
-<img src=assets/modelfigure.png />
-</p>
-
-## News
-
-### April 2022
-- Thanks to [Katherine Crowson](https://github.com/crowsonkb), classifier-free guidance received a ~2x speedup and the [PLMS sampler](https://arxiv.org/abs/2202.09778) is available. See also [this PR](https://github.com/CompVis/latent-diffusion/pull/51).
-
-- Our 1.45B [latent diffusion LAION model](#text-to-image) was integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/multimodalart/latentdiffusion)
-
-- More pre-trained LDMs are available:
-  - A 1.45B [model](#text-to-image) trained on the [LAION-400M](https://arxiv.org/abs/2111.02114) database.
-  - A class-conditional model on ImageNet, achieving a FID of 3.6 when using [classifier-free guidance](https://openreview.net/pdf?id=qw8AKxfYbI) Available via a [colab notebook](https://colab.research.google.com/github/CompVis/latent-diffusion/blob/main/scripts/latent_imagenet_diffusion.ipynb) [![][colab]][colab-cin].
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
-
-```
-conda env create -f environment.yaml
-conda activate ldm
-```
-
-# Pretrained Models
-A general list of all available checkpoints is available in via our [model zoo](#model-zoo).
-If you use any of these models in your work, we are always happy to receive a [citation](#bibtex).
-
-## Text-to-Image
-![text2img-figure](assets/txt2img-preview.png)
-
-Download the pre-trained weights (5.7GB)
-```
-mkdir -p models/ldm/text2img-large/
-wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
-```
-and sample with
-```
-python scripts/txt2img.py --prompt "a virus monster is playing guitar, oil on canvas" --ddim_eta 0.0 --n_samples 4 --n_iter 4 --scale 5.0 --ddim_steps 50
-```
-This will save each sample individually as well as a grid of size `n_iter` x `n_samples` at the specified output location (default: `outputs/txt2img-samples`).
-Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` arguments.
-As a rule of thumb, higher values of `scale` produce better samples at the cost of a reduced output diversity.
-Furthermore, increasing `ddim_steps` generally also gives higher quality samples, but returns are diminishing for values > 250.
-Fast sampling (i.e. low values of `ddim_steps`) while retaining good quality can be achieved by using `--ddim_eta 0.0`.
-Faster sampling (i.e. even lower values of `ddim_steps`) while retaining good quality can be achieved by using `--ddim_eta 0.0` and `--plms` (see [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)).
-
-#### Beyond 256²
-
-For certain inputs, simply running the model in a convolutional fashion on larger features than it was trained on
-can sometimes result in interesting results. To try it out, tune the `H` and `W` arguments (which will be integer-divided
-by 8 in order to calculate the corresponding latent size), e.g. run
-
-```
-python scripts/txt2img.py --prompt "a sunset behind a mountain range, vector image" --ddim_eta 1.0 --n_samples 1 --n_iter 1 --H 384 --W 1024 --scale 5.0
-```
-to create a sample of size 384x1024. Note, however, that controllability is reduced compared to the 256x256 setting.
-
-The example below was generated using the above command.
-![text2img-figure-conv](assets/txt2img-convsample.png)
-
-## Inpainting
-![inpainting](assets/inpainting.png)
-
-Download the pre-trained weights
-```
-wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
-```
-
-and sample with
-```
-python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results
-```
-`indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
-the examples provided in `data/inpainting_examples`.
-
-## Class-Conditional ImageNet
-
-Available via a [notebook](scripts/latent_imagenet_diffusion.ipynb) [![][colab]][colab-cin].
-![class-conditional](assets/birdhouse.png)
-
-[colab]: <https://colab.research.google.com/assets/colab-badge.svg>
-[colab-cin]: <https://colab.research.google.com/github/CompVis/latent-diffusion/blob/main/scripts/latent_imagenet_diffusion.ipynb>
-
-## Unconditional Models
-
-We also provide a script for sampling from unconditional LDMs (e.g. LSUN, FFHQ, ...). Start it via
-
-```shell script
-CUDA_VISIBLE_DEVICES=<GPU_ID> python scripts/sample_diffusion.py -r models/ldm/<model_spec>/model.ckpt -l <logdir> -n <\#samples> --batch_size <batch_size> -c <\#ddim steps> -e <\#eta>
-```
-
-# Train your own LDMs
-
-## Data preparation
-
-### Faces
-For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
-repository.
-
-### LSUN
-
-The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
-We performed a custom split into training and validation images, and provide the corresponding filenames
-at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
-After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
-also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
-
-### ImageNet
-The code will try to download (through [Academic
-Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
-is used. However, since ImageNet is quite large, this requires a lot of disk
-space and time. If you already have ImageNet on your disk, you can speed things
-up by putting the data into
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
-`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
-of `train`/`validation`. It should have the following structure:
-
-```
-${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
-├── n01440764
-│   ├── n01440764_10026.JPEG
-│   ├── n01440764_10027.JPEG
-│   ├── ...
-├── n01443537
-│   ├── n01443537_10007.JPEG
-│   ├── n01443537_10014.JPEG
-│   ├── ...
-├── ...
-```
-
-If you haven't extracted the data, you can also place
-`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
-extracted into above structure without downloading it again. Note that this
-will only happen if neither a folder
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
-`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
-if you want to force running the dataset preparation again.
-
-## Model Training
-
-Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
-
-### Training autoencoder models
-
-Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
-Training can be started by running
-```
-CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
-```
-where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
-`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
-
-For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
-repository.
-
-### Training LDMs
-
-In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
-Training can be started by running
-
-```shell script
-CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
-```
-
-where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
-`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
-`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
-
-# Model Zoo
-
-## Pretrained Autoencoding Models
-![rec2](assets/reconstruction2.png)
-
-All models were trained until convergence (no further substantial improvement in rFID).
-
-| Model                   | rFID vs val | train steps | PSNR           | PSIM          | Link                                                          | Comments     |
-|-------------------------|-------------|-------------|----------------|---------------|---------------------------------------------------------------|--------------|
-| f=4, VQ (Z=8192, d=3)   | 0.58        | 533066      | 27.43 +/- 4.26 | 0.53 +/- 0.21 | https://ommer-lab.com/files/latent-diffusion/vq-f4.zip        |              |
-| f=4, VQ (Z=8192, d=3)   | 1.06        | 658131      | 25.21 +/- 4.17 | 0.72 +/- 0.26 | https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 | no attention |
-| f=8, VQ (Z=16384, d=4)  | 1.14        | 971043      | 23.07 +/- 3.99 | 1.17 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/vq-f8.zip        |              |
-| f=8, VQ (Z=256, d=4)    | 1.49        | 1608649     | 22.35 +/- 3.81 | 1.26 +/- 0.37 | https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip   |              |
-| f=16, VQ (Z=16384, d=8) | 5.15        | 1101166     | 20.83 +/- 3.61 | 1.73 +/- 0.43 | https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 |              |
-|                         |             |             |                |               |                                                               |              |
-| f=4, KL                 | 0.27        | 176991      | 27.53 +/- 4.54 | 0.55 +/- 0.24 | https://ommer-lab.com/files/latent-diffusion/kl-f4.zip        |              |
-| f=8, KL                 | 0.90        | 246803      | 24.19 +/- 4.19 | 1.02 +/- 0.35 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip        |              |
-| f=16, KL (d=16)         | 0.87        | 442998      | 24.08 +/- 4.22 | 1.07 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/kl-f16.zip       |              |
-| f=32, KL (d=64)         | 2.04        | 406763      | 22.27 +/- 3.93 | 1.41 +/- 0.40 | https://ommer-lab.com/files/latent-diffusion/kl-f32.zip       |              |
-
-### Get the models
-
-Running the following script downloads und extracts all available pretrained autoencoding models.
-```shell script
-bash scripts/download_first_stages.sh
-```
-
-The first stage models can then be found in `models/first_stage_models/<model_spec>`
-
-## Pretrained LDMs
-| Datset | Task | Model | FID | IS | Prec | Recall | Link | Comments |
-|--------|------|-------|-----|----|------|--------|------|----------|
-| CelebA-HQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 5.11 (5.11) | 3.29 | 0.72 | 0.49 | https://ommer-lab.com/files/latent-diffusion/celeba.zip | |
-| FFHQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1) | 4.98 (4.98) | 4.50 (4.50) | 0.73 | 0.50 | https://ommer-lab.com/files/latent-diffusion/ffhq.zip | |
-| LSUN-Churches | Unconditional Image Synthesis | LDM-KL-8 (400 DDIM steps, eta=0) | 4.02 (4.02) | 2.72 | 0.64 | 0.52 | https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip | |
-| LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1) | 2.95 (3.0) | 2.22 (2.23) | 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
-| ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
-| Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
-| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
-| OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
-| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
-| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
-
-### Get the models
-
-The LDMs listed above can jointly be downloaded and extracted via
-
-```shell script
-bash scripts/download_models.sh
-```
-
-The models can then be found in `models/ldm/<model_spec>`.
-
-## Coming Soon...
-
-* More inference scripts for conditional LDMs.
-* In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
-
-## Comments
-
-- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
-and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
-Thanks for open-sourcing!
-
-- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
-
-## BibTeX
-
-```
-@misc{rombach2021highresolution,
-      title={High-Resolution Image Synthesis with Latent Diffusion Models},
-      author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
-      year={2021},
-      eprint={2112.10752},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV}
-}
-```
+# Experiments with Stable Diffusion
+
+## Image variations
+
+[![](assets/img-vars.jpg)](https://twitter.com/Buntworthy/status/1561703483316781057)
+
+_TODO describe in more detail_
+
+- Get model from huggingface hub [lambdalabs/stable-diffusion-image-conditioned](https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned/blob/main/sd-clip-vit-l14-img-embed_ema_only.ckpt)
+- Put model in `models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt`
+- Run `scripts/image_variations.py` or `scripts/gradio_variations.py`
+
+Trained by [Justin Pinkney](https://www.justinpinkney.com) ([@Buntworthy](https://twitter.com/Buntworthy)) at [Lambda](https://lambdalabs.com/)
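For reference, the three README steps above can be followed roughly as below; the use of `wget` and the `resolve/main` form of the download URL are assumptions, not part of this commit:

```
# 1. fetch the image-conditioned checkpoint from the Hugging Face Hub (URL form assumed)
wget https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned/resolve/main/sd-clip-vit-l14-img-embed_ema_only.ckpt

# 2. place it where the scripts expect it
mkdir -p models/ldm/stable-diffusion-v1/
mv sd-clip-vit-l14-img-embed_ema_only.ckpt models/ldm/stable-diffusion-v1/

# 3. run either entry point
python scripts/image_variations.py --im_path <path/to/an/image.jpg>
python scripts/gradio_variations.py
```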
assets/img-vars.jpg (new binary file, 443 KiB)
configs/stable-diffusion/sd-image-condition-finetune.yaml (new file, +134)

@@ -0,0 +1,134 @@

```yaml
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "jpg"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: "/mnt/data_rome/laion/improved_aesthetics_6plus/ims"
    batch_size: 6
    num_workers: 4
    multinode: True
    min_size: 256
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00008}.tar -'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512


lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
```
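As an illustrative sketch (not part of the diff), a config like this is turned into a model object via `instantiate_from_config`, mirroring `load_model_from_config` in the new `scripts/image_variations.py`; the checkpoint path is the one assumed by the README:

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# build LatentDiffusion with the FrozenCLIPImageEmbedder conditioning stage
config = OmegaConf.load("configs/stable-diffusion/sd-image-condition-finetune.yaml")
model = instantiate_from_config(config.model)

# load the fine-tuned weights (local path assumed, as suggested in the README)
sd = torch.load("models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt",
                map_location="cpu")["state_dict"]
missing, unexpected = model.load_state_dict(sd, strict=False)
model.eval()
```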
ldm/models/diffusion/ddpm.py

@@ -523,7 +523,7 @@ class LatentDiffusion(DDPM):
         self.instantiate_cond_stage(cond_stage_config)
         self.cond_stage_forward = cond_stage_forward
         self.clip_denoised = False
         self.bbox_tokenizer = None

         self.restarted_from_ckpt = False
         if ckpt_path is not None:

@@ -904,7 +904,7 @@ class LatentDiffusion(DDPM):

         if hasattr(self, "split_input_params"):
             assert len(cond) == 1  # todo can only deal with one conditioning atm
             assert not return_ids
             ks = self.split_input_params["ks"]  # eg. (128, 128)
             stride = self.split_input_params["stride"]  # eg. (64, 64)

@@ -1343,7 +1343,9 @@ class LatentDiffusion(DDPM):
             log["samples_x0_quantized"] = x_samples

         if unconditional_guidance_scale > 1.0:
-            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # FIXME
+            uc = torch.zeros_like(c)
             with ema_scope("Sampling with classifier-free guidance"):
                 samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                  ddim_steps=ddim_steps, eta=ddim_eta,
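The change above makes image logging use an all-zero image embedding as the unconditional input for classifier-free guidance, matching the new sampling scripts; a short sketch of that pairing (variable names are assumptions):

```python
# c: CLIP image embedding from get_learned_conditioning, shape (N, 1, 768)
c = model.get_learned_conditioning(input_im)
uc = torch.zeros_like(c)  # "unconditional" signal is simply a zero embedding
# the sampler then combines the two noise predictions as
#   eps = eps_uncond + unconditional_guidance_scale * (eps_cond - eps_uncond)
```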
ldm/modules/encoders/modules.py

@@ -2,9 +2,11 @@ import torch
 import torch.nn as nn
 import numpy as np
 from functools import partial
+import kornia

 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
 from ldm.util import default
+import clip


 class AbstractEncoder(nn.Module):

@@ -170,6 +172,42 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text):
         return self(text)


+class FrozenCLIPImageEmbedder(AbstractEncoder):
+    """
+    Uses the CLIP image encoder.
+    """
+    def __init__(
+            self,
+            model='ViT-L/14',
+            jit=False,
+            device='cuda' if torch.cuda.is_available() else 'cpu',
+            antialias=False,
+        ):
+        super().__init__()
+        self.model, _ = clip.load(name=model, device=device, jit=jit)
+        self.device = device
+
+        self.antialias = antialias
+
+        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+    def preprocess(self, x):
+        # Expects inputs in the range -1, 1
+        x = kornia.geometry.resize(x, (224, 224),
+                                   interpolation='bicubic', align_corners=True,
+                                   antialias=self.antialias)
+        x = (x + 1.) / 2.
+        # renormalize according to clip
+        x = kornia.enhance.normalize(x, self.mean, self.std)
+        return x
+
+    def forward(self, x):
+        # x is assumed to be in range [-1,1]
+        return self.model.encode_image(self.preprocess(x)).float()
+
+    def encode(self, im):
+        return self(im).unsqueeze(1)
+
+
 class SpatialRescaler(nn.Module):
     def __init__(self,
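A minimal usage sketch of the new embedder (the random input batch and its scaling to [-1, 1] are assumptions; `encode` adds the sequence dimension expected by the cross-attention conditioning, giving shape (B, 1, 768) for ViT-L/14):

```python
import torch
from ldm.modules.encoders.modules import FrozenCLIPImageEmbedder

embedder = FrozenCLIPImageEmbedder(device="cpu")  # loads CLIP ViT-L/14 via the clip package
ims = torch.rand(2, 3, 256, 256) * 2 - 1          # fake RGB batch scaled to [-1, 1]
cond = embedder.encode(ims)                       # -> (2, 1, 768) conditioning tensor
print(cond.shape)
```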
main.py (25 lines changed)

@@ -21,7 +21,7 @@ from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config


-MULTINODE_HACKS = True
+MULTINODE_HACKS = False


 def get_parser(**parser_kwargs):

@@ -36,6 +36,13 @@ def get_parser(**parser_kwargs):
             raise argparse.ArgumentTypeError("Boolean value expected.")

     parser = argparse.ArgumentParser(**parser_kwargs)
+    parser.add_argument(
+        "--finetune_from",
+        type=str,
+        nargs="?",
+        default="",
+        help="path to checkpoint to load model state from"
+    )
     parser.add_argument(
         "-n",
         "--name",

@@ -644,6 +651,20 @@ if __name__ == "__main__":
         # model
         model = instantiate_from_config(config.model)

+        if not opt.finetune_from == "":
+            print(f"Attempting to load state from {opt.finetune_from}")
+            old_state = torch.load(opt.finetune_from, map_location="cpu")
+            if "state_dict" in old_state:
+                print(f"Found nested key 'state_dict' in checkpoint, loading this instead")
+                old_state = old_state["state_dict"]
+            m, u = model.load_state_dict(old_state, strict=False)
+            if len(m) > 0:
+                print("missing keys:")
+                print(m)
+            if len(u) > 0:
+                print("unexpected keys:")
+                print(u)
+
         # trainer and callbacks
         trainer_kwargs = dict()

@@ -666,7 +687,7 @@ if __name__ == "__main__":
                 }
             },
         }
-        default_logger_cfg = default_logger_cfgs["testtube"]
+        default_logger_cfg = default_logger_cfgs["wandb"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
         else:
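Put together, the new `--finetune_from` flag and the config added above would be used roughly like this; the GPU selection and the base checkpoint path are placeholders, not taken from the commit:

```
CUDA_VISIBLE_DEVICES=0 python main.py \
    --base configs/stable-diffusion/sd-image-condition-finetune.yaml \
    -t --gpus 0, \
    --finetune_from <path/to/base/stable-diffusion.ckpt>
```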
requirements.txt

@@ -1,17 +1,20 @@
 albumentations==0.4.3
-opencv-python
+opencv-python==4.5.5.64
 pudb==2019.2
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
 pytorch-lightning==1.4.2
-torchmetrics==0.6
 omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
 einops==0.3.0
 torch-fidelity==0.3.0
-transformers==4.19.2
+transformers
+kornia==0.6
 webdataset==0.2.5
+torchmetrics==0.6.0
+fire==0.4.0
+gradio==3.2
 -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
 -e git+https://github.com/openai/CLIP.git@main#egg=clip
 -e .
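If the conda environment from the repository's `environment.yaml` is not used, the updated pins can be installed directly, assuming a working PyTorch install is already present:

```
pip install -r requirements.txt
```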
scripts/gradio_variations.py (new file, +112)

@@ -0,0 +1,112 @@

```python
from contextlib import nullcontext
from functools import partial

import fire
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import transforms

from scripts.image_variations import load_model_from_config


@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta):
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)

            if scale != 1.0:
                uc = torch.zeros_like(c)
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()


def main(
    model,
    device,
    input_im,
    scale=3.0,
    n_samples=4,
    plms=True,
    ddim_steps=50,
    ddim_eta=1.0,
    precision="fp32",
    h=512,
    w=512,
    ):

    input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
    input_im = input_im*2-1

    if plms:
        sampler = PLMSSampler(model)
        ddim_eta = 0.0
    else:
        sampler = DDIMSampler(model)

    x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta)
    output_ims = []
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))
    return output_ims


def run_demo(
    device_idx=0,
    ckpt="models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt",
    config="configs/stable-diffusion/sd-image-condition-finetune.yaml",
    ):

    device = f"cuda:{device_idx}"
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt, device=device)

    inputs = [
        gr.Image(),
        gr.Slider(0, 25, value=3, step=1, label="cfg scale"),
        gr.Slider(1, 4, value=1, step=1, label="Number images"),
        gr.Checkbox(True, label="plms"),
        gr.Slider(5, 250, value=25, step=5, label="steps"),
    ]
    output = gr.Gallery(label="Generated variations")
    output.style(height="auto", grid=2)

    fn_with_model = partial(main, model, device)
    fn_with_model.__name__ = "fn_with_model"

    demo = gr.Interface(
        fn=fn_with_model,
        title="Stable Diffusion Image Variations",
        description="Generate variations on an input image using a fine-tuned version of Stable Diffision",
        article="TODO",
        inputs=inputs,
        outputs=output,
        )
    # demo.queue()
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    fire.Fire(run_demo)
```
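The demo is launched with a single command; because `run_demo` is wrapped in `fire.Fire`, its keyword arguments can be overridden from the command line (the alternative device index below is only an example):

```
python scripts/gradio_variations.py
# or, overriding the defaults:
python scripts/gradio_variations.py --device_idx 1 --ckpt models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt
```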
scripts/image_variations.py (new file, +122)

@@ -0,0 +1,122 @@

```python
from io import BytesIO
import os
from contextlib import nullcontext

import fire
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import transforms
import requests

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt, device, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.to(device)
    model.eval()
    return model


def load_im(im_path):
    if im_path.startswith("http"):
        response = requests.get(im_path)
        response.raise_for_status()
        im = Image.open(BytesIO(response.content))
    else:
        im = Image.open(im_path).convert("RGB")
    tforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
    ])
    inp = tforms(im).unsqueeze(0)
    return inp*2-1


@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta):
    precision_scope = autocast if precision == "autocast" else nullcontext
    with precision_scope("cuda"):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)

            if scale != 1.0:
                uc = torch.zeros_like(c)
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)


def main(
    im_path="data/example_conditioning/superresolution/sample_0.jpg",
    ckpt="models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt",
    config="configs/stable-diffusion/sd-image-condition-finetune.yaml",
    outpath="im_variations",
    scale=3.0,
    h=512,
    w=512,
    n_samples=4,
    precision="fp32",
    plms=True,
    ddim_steps=50,
    ddim_eta=1.0,
    device_idx=0,
    ):

    device = f"cuda:{device_idx}"
    input_im = load_im(im_path).to(device)
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt, device=device)

    if plms:
        sampler = PLMSSampler(model)
        ddim_eta = 0.0
    else:
        sampler = DDIMSampler(model)

    os.makedirs(outpath, exist_ok=True)

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))

    x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, ddim_eta)
    for x_sample in x_samples_ddim:
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        Image.fromarray(x_sample.astype(np.uint8)).save(
            os.path.join(sample_path, f"{base_count:05}.png"))
        base_count += 1


if __name__ == "__main__":
    fire.Fire(main)
```
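Example invocation of the batch script; the input image path is a placeholder, while every flag shown is a real parameter of `main` above:

```
python scripts/image_variations.py \
    --im_path <path/or/url/to/image.jpg> \
    --n_samples 4 --scale 3.0 --ddim_steps 50 \
    --outpath outputs/im_variations
```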