stable-diffusion-finetune/scripts/latent_imagenet_diffusion.ipynb

Class-Conditional Synthesis with Latent Diffusion Models

Install all the requirements

In [ ]:
#@title Installation
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
!pip install -e ./taming-transformers
!pip install "omegaconf>=2.0.0" "pytorch-lightning>=1.0.8" torch-fidelity einops  # quote the version pins so the shell does not treat >= as a redirect

import sys
sys.path.append(".")
sys.path.append('./taming-transformers')
from taming.models import vqgan
Cloning into 'latent-diffusion'...
remote: Enumerating objects: 992, done.
remote: Counting objects: 100% (695/695), done.
remote: Compressing objects: 100% (397/397), done.
remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297
Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.
Resolving deltas: 100% (510/510), done.
Cloning into 'taming-transformers'...
remote: Enumerating objects: 1335, done.
remote: Counting objects: 100% (525/525), done.
remote: Compressing objects: 100% (493/493), done.
remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810
Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.
Resolving deltas: 100% (267/267), done.
Obtaining file:///content/taming-transformers
Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.10.0+cu111)
Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.21.5)
Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (4.63.0)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->taming-transformers==0.0.1) (3.10.0.2)
Installing collected packages: taming-transformers
  Running setup.py develop for taming-transformers
Successfully installed taming-transformers-0.0.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.
arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.
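
A quick sanity check confirms the editable install and a CUDA-enabled PyTorch before moving on (a minimal optional sketch; both imports were already exercised above):

In [ ]:
#@title Sanity check (optional)
import torch
from taming.models import vqgan  # should import cleanly after the editable install

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())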

Now, download the checkpoint (~1.7 GB). This will usually take 1-2 minutes.

In [ ]:
#@title Download
%cd latent-diffusion/ 

!mkdir -p models/ldm/cin256-v2/
!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt
/content/latent-diffusion
--2022-04-03 13:04:51--  https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt
Resolving ommer-lab.com (ommer-lab.com)... 141.84.41.65
Connecting to ommer-lab.com (ommer-lab.com)|141.84.41.65|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1827378153 (1.7G)
Saving to: models/ldm/cin256-v2/model.ckpt

models/ldm/cin256-v 100%[===================>]   1.70G  24.9MB/s    in 70s     

2022-04-03 13:06:02 (24.9 MB/s) - models/ldm/cin256-v2/model.ckpt saved [1827378153/1827378153]
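
Interrupted downloads are a common failure mode on Colab. Checking the file size against the byte count reported by wget above catches a truncated checkpoint early (a minimal optional sketch):

In [ ]:
#@title Verify download (optional)
import os

ckpt_size = os.path.getsize("models/ldm/cin256-v2/model.ckpt")
print(f"{ckpt_size / 2**30:.2f} GiB")
# The complete checkpoint is 1827378153 bytes (~1.70 GiB); anything smaller
# means the download was cut off -- re-run the wget cell in that case.
assert ckpt_size == 1827378153, "checkpoint appears incomplete"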

Let's also check what type of GPU we've got.

In [ ]:
!nvidia-smi
Sun Apr  3 13:06:21 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P8    33W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
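
The same information is available from inside Python, which is convenient when scripting (a minimal sketch using standard torch.cuda calls):

In [ ]:
import torch

props = torch.cuda.get_device_properties(0)
print(props.name, f"{props.total_memory / 2**20:.0f} MiB")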

Now define the loading utilities and load the model.

In [ ]:
#@title loading utils
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt):
    print(f"Loading model from {ckpt}")
    # Load the Lightning checkpoint on the CPU first to avoid GPU
    # out-of-memory during deserialization; .cuda() moves it afterwards.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates keys that differ between training and
    # inference; missing/unexpected entries are returned, not fatal.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()  # inference mode: disables dropout etc.
    return model


def get_model():
    config = OmegaConf.load("configs/latent-diffusion/cin256-v2.yaml")  
    model = load_model_from_config(config, "models/ldm/cin256-v2/model.ckpt")
    return model
In [ ]:
from ldm.models.diffusion.ddim import DDIMSampler

model = get_model()
sampler = DDIMSampler(model)
Loading model from models/ldm/cin256-v2/model.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 400.92 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels

And go. Quality, sampling speed, and diversity are best controlled via the scale, ddim_steps, and ddim_eta variables. As a rule of thumb, higher values of scale produce better samples at the cost of reduced output diversity. Increasing ddim_steps generally also gives higher-quality samples, but returns diminish for values > 250. Fast sampling (i.e., low values of ddim_steps) while retaining good quality can be achieved by using ddim_eta = 0.0.
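
For reference, scale is the weight of classifier-free guidance: at each step the model is queried once with the class conditioning and once with the unconditional conditioning, and the two noise estimates are blended. Conceptually (a sketch of the idea; DDIMSampler performs the equivalent combination internally when unconditional_conditioning is passed):

In [ ]:
def guided_eps(e_c, e_uc, scale):
    # e_c: class-conditional noise estimate, e_uc: unconditional estimate.
    # scale = 1.0 recovers the purely conditional prediction; larger values
    # push samples toward the class at the cost of diversity.
    return e_uc + scale * (e_c - e_uc)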

In [ ]:
import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid


classes = [25, 187, 448, 992]   # ImageNet class indices to sample
n_samples_per_class = 6

ddim_steps = 20   # number of DDIM sampling steps
ddim_eta = 0.0    # eta = 0.0 makes DDIM sampling deterministic
scale = 3.0       # classifier-free guidance scale


all_samples = list()

with torch.no_grad():
    with model.ema_scope():
        # Class index 1000 is the learned "unconditional" embedding used
        # for classifier-free guidance (the 1000 ImageNet classes occupy
        # indices 0-999).
        uc = model.get_learned_conditioning(
            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}
            )
        
        for class_label in classes:
            print(f"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.")
            xc = torch.tensor(n_samples_per_class*[class_label])
            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})
            
            # Sample in latent space: [3, 64, 64] latents decode to 256x256 RGB.
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=c,
                                             batch_size=n_samples_per_class,
                                             shape=[3, 64, 64],
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta)

            # Decode latents to images and map pixel values from [-1, 1] to [0, 1].
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                         min=0.0, max=1.0)
            all_samples.append(x_samples_ddim)


# display as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_samples_per_class)

# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8))
rendering 6 examples of class '25' in 20 steps and using s=3.00.
Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps
DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00,  1.89s/it]
rendering 6 examples of class '187' in 20 steps and using s=3.00.
Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps
DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00,  1.87s/it]
rendering 6 examples of class '448' in 20 steps and using s=3.00.
Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps
DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00,  1.86s/it]
rendering 6 examples of class '992' in 20 steps and using s=3.00.
Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps
DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00,  1.86s/it]
Out[ ]: (image output: a 4×6 grid of 256×256 samples, one row per class)
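
To keep the result across Colab sessions, the grid computed above can be written to disk (a minimal sketch; samples.png is an arbitrary filename):

In [ ]:
# 'grid' is the H x W x C array computed in the sampling cell above.
Image.fromarray(grid.astype(np.uint8)).save("samples.png")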