2022-04-04 16:17:48 +02:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "latent-imagenet-diffusion.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Class-Conditional Synthesis with Latent Diffusion Models"
],
"metadata": {
"id": "NUmmV5ZvrPbP"
}
},
{
"cell_type": "markdown",
"source": [
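"Install the requirements: clone the `latent-diffusion` and `taming-transformers` repositories and install their Python dependencies."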
],
"metadata": {
"id": "zh7u8gOx0ivw"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHgUAp48qwoG",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "411d4df6-d91a-42d4-819e-9cf641c12248",
"cellView": "form"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'latent-diffusion'...\n",
"remote: Enumerating objects: 992, done.\u001B[K\n",
"remote: Counting objects: 100% (695/695), done.\u001B[K\n",
"remote: Compressing objects: 100% (397/397), done.\u001B[K\n",
"remote: Total 992 (delta 375), reused 564 (delta 253), pack-reused 297\u001B[K\n",
"Receiving objects: 100% (992/992), 30.78 MiB | 29.43 MiB/s, done.\n",
"Resolving deltas: 100% (510/510), done.\n",
"Cloning into 'taming-transformers'...\n",
"remote: Enumerating objects: 1335, done.\u001B[K\n",
"remote: Counting objects: 100% (525/525), done.\u001B[K\n",
"remote: Compressing objects: 100% (493/493), done.\u001B[K\n",
"remote: Total 1335 (delta 58), reused 481 (delta 30), pack-reused 810\u001B[K\n",
"Receiving objects: 100% (1335/1335), 412.35 MiB | 30.53 MiB/s, done.\n",
"Resolving deltas: 100% (267/267), done.\n",
"Obtaining file:///content/taming-transformers\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.10.0+cu111)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (1.21.5)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from taming-transformers==0.0.1) (4.63.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->taming-transformers==0.0.1) (3.10.0.2)\n",
"Installing collected packages: taming-transformers\n",
" Running setup.py develop for taming-transformers\n",
"Successfully installed taming-transformers-0.0.1\n",
"\u001B[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.8.0 requires tf-estimator-nightly==2.8.0.dev2021122109, which is not installed.\n",
"arviz 0.11.4 requires typing-extensions<4,>=3.7.4.3, but you have typing-extensions 4.1.1 which is incompatible.\u001B[0m\n"
]
}
],
"source": [
"#@title Installation\n",
"!git clone https://github.com/CompVis/latent-diffusion.git\n",
"!git clone https://github.com/CompVis/taming-transformers\n",
"!pip install -e ./taming-transformers\n",
"!pip install 'omegaconf>=2.0.0' 'pytorch-lightning>=1.0.8' torch-fidelity einops\n",
"\n",
"import sys\n",
"sys.path.append(\".\")\n",
"sys.path.append('./taming-transformers')\n",
"from taming.models import vqgan "
]
},
{
"cell_type": "markdown",
"source": [
"Now, download the checkpoint (~1.7 GB). This will usually take 1-2 minutes."
],
"metadata": {
"id": "fNqCqQDoyZmq"
}
},
{
"cell_type": "code",
"source": [
"#@title Download\n",
"%cd latent-diffusion/ \n",
"\n",
"!mkdir -p models/ldm/cin256-v2/\n",
"!wget -O models/ldm/cin256-v2/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt "
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cNHvQBhzyXCI",
"outputId": "0a79e979-8484-4c62-96d9-7c79b1835162",
"cellView": "form"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/latent-diffusion\n",
"--2022-04-03 13:04:51-- https://ommer-lab.com/files/latent-diffusion/nitro/cin/model.ckpt\n",
"Resolving ommer-lab.com (ommer-lab.com)... 141.84.41.65\n",
"Connecting to ommer-lab.com (ommer-lab.com)|141.84.41.65|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1827378153 (1.7G)\n",
"Saving to: models/ldm/cin256-v2/model.ckpt\n",
"\n",
"models/ldm/cin256-v 100%[===================>] 1.70G 24.9MB/s in 70s \n",
"\n",
"2022-04-03 13:06:02 (24.9 MB/s) - models/ldm/cin256-v2/model.ckpt saved [1827378153/1827378153]\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Let's also check what type of GPU we've got."
],
"metadata": {
"id": "ThxmCePqt1mt"
}
},
{
"cell_type": "code",
"source": [
"!nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jbL2zJ7Pt7Jl",
"outputId": "c8242be9-dba2-4a9f-da44-a294a70bb449"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sun Apr 3 13:06:21 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 66C P8 33W / 149W | 0MiB / 11441MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Load the model from the downloaded checkpoint and set up a DDIM sampler for it."
],
"metadata": {
"id": "1tWAqdwk0Nrn"
}
},
{
"cell_type": "code",
"source": [
"#@title loading utils\n",
"import torch\n",
"from omegaconf import OmegaConf\n",
"\n",
"from ldm.util import instantiate_from_config\n",
"\n",
"\n",
"def load_model_from_config(config, ckpt):\n",
"    print(f\"Loading model from {ckpt}\")\n",
"    # load the checkpoint onto the CPU first; the model is moved to the GPU below\n",
"    pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
"    sd = pl_sd[\"state_dict\"]\n",
"    model = instantiate_from_config(config.model)\n",
"    missing, unexpected = model.load_state_dict(sd, strict=False)\n",
"    if missing or unexpected:\n",
"        print(f\"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}\")\n",
"    model.cuda()\n",
"    model.eval()\n",
"    return model\n",
"\n",
"\n",
"def get_model():\n",
"    config = OmegaConf.load(\"configs/latent-diffusion/cin256-v2.yaml\")\n",
"    model = load_model_from_config(config, \"models/ldm/cin256-v2/model.ckpt\")\n",
"    return model"
],
"metadata": {
"id": "fnGwQRhtyBhb",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"\n",
"model = get_model()\n",
"sampler = DDIMSampler(model)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BPnyd-XUKbfE",
"outputId": "0fcd10e4-0df2-4ab9-cbf5-f08f4902c954"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loading model from models/ldm/cin256-v2/model.ckpt\n",
"LatentDiffusion: Running in eps-prediction mode\n",
"DiffusionWrapper has 400.92 M params.\n",
"making attention of type 'vanilla' with 512 in_channels\n",
"Working with z of shape (1, 3, 64, 64) = 12288 dimensions.\n",
"making attention of type 'vanilla' with 512 in_channels\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"And go. Quality, sampling speed and diversity are best controlled via the `scale`, `ddim_steps` and `ddim_eta` variables. `scale` is the classifier-free guidance scale: as a rule of thumb, higher values produce better samples at the cost of reduced output diversity. Increasing `ddim_steps` also generally gives higher-quality samples, but returns diminish for values above 250. Fast sampling (i.e. low values of `ddim_steps`) that still retains good quality can be achieved by using `ddim_eta = 0.0`."
],
"metadata": {
"id": "iIEAhY8AhUrh"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np \n",
"from PIL import Image\n",
"from einops import rearrange\n",
"from torchvision.utils import make_grid\n",
"\n",
"\n",
"classes = [25, 187, 448, 992] # define classes to be sampled here\n",
"n_samples_per_class = 6\n",
"\n",
"ddim_steps = 20\n",
"ddim_eta = 0.0\n",
"scale = 3.0  # classifier-free guidance scale\n",
"\n",
"\n",
"all_samples = list()\n",
"\n",
"with torch.no_grad():\n",
"    with model.ema_scope():\n",
"        # class index 1000 (one past the 1000 ImageNet classes) is used as the\n",
"        # unconditional label for classifier-free guidance\n",
"        uc = model.get_learned_conditioning(\n",
"            {model.cond_stage_key: torch.tensor(n_samples_per_class*[1000]).to(model.device)}\n",
"        )\n",
"\n",
"        for class_label in classes:\n",
"            print(f\"rendering {n_samples_per_class} examples of class '{class_label}' in {ddim_steps} steps and using s={scale:.2f}.\")\n",
"            xc = torch.tensor(n_samples_per_class*[class_label])\n",
"            c = model.get_learned_conditioning({model.cond_stage_key: xc.to(model.device)})\n",
"\n",
"            samples_ddim, _ = sampler.sample(\n",
"                S=ddim_steps,\n",
"                conditioning=c,\n",
"                batch_size=n_samples_per_class,\n",
"                shape=[3, 64, 64],\n",
"                verbose=False,\n",
"                unconditional_guidance_scale=scale,\n",
"                unconditional_conditioning=uc,\n",
"                eta=ddim_eta,\n",
"            )\n",
"\n",
"            # decode latents to images and map from [-1, 1] to [0, 1]\n",
"            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
"            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
"            all_samples.append(x_samples_ddim)\n",
"\n",
"\n",
"# display as grid\n",
"grid = torch.stack(all_samples, 0)\n",
"grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
"grid = make_grid(grid, nrow=n_samples_per_class)\n",
"\n",
"# to image\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8))"
],
"metadata": {
"id": "jcbqWX2Ytu9t",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "3b7adde0-d80e-4c01-82d2-bf988aee7455"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"rendering 6 examples of class '25' in 20 steps and using s=3.00.\n",
"Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n",
"Running DDIM Sampling with 20 timesteps\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.89s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"rendering 6 examples of class '187' in 20 steps and using s=3.00.\n",
"Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n",
"Running DDIM Sampling with 20 timesteps\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.87s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"rendering 6 examples of class '448' in 20 steps and using s=3.00.\n",
"Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n",
"Running DDIM Sampling with 20 timesteps\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.86s/it]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"rendering 6 examples of class '992' in 20 steps and using s=3.00.\n",
"Data shape for DDIM sampling is (6, 3, 64, 64), eta 0.0\n",
"Running DDIM Sampling with 20 timesteps\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"DDIM Sampler: 100%|██████████| 20/20 [00:37<00:00, 1.86s/it]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.Image.Image image mode=RGB size=1550x1034 at 0x7FF8B6840F50>"
]
},
"metadata": {},
"execution_count": 6
}
]
},
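{
"cell_type": "markdown",
"source": [
"For intuition about the `scale` and `ddim_eta` settings described above, here is a minimal, self-contained sketch that uses dummy tensors instead of the model. It assumes the standard classifier-free guidance rule (the unconditional noise prediction plus `scale` times the difference between the class-conditional and unconditional predictions) and the per-step DDIM noise level from the DDIM paper, which is proportional to `eta`. `cfg_combine` and `ddim_sigma` are hypothetical helper names chosen for illustration; they are not functions of the `ldm` package."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"# Illustrative sketch only: dummy tensors stand in for the diffusion model's noise predictions.\n",
"\n",
"def cfg_combine(eps_uncond, eps_cond, scale):\n",
"    # classifier-free guidance: start from the unconditional prediction and move\n",
"    # `scale` times along the direction towards the class-conditional one;\n",
"    # scale=1.0 returns eps_cond unchanged\n",
"    return eps_uncond + scale * (eps_cond - eps_uncond)\n",
"\n",
"\n",
"def ddim_sigma(alpha_bar_t, alpha_bar_prev, eta):\n",
"    # standard deviation of the noise injected by one DDIM step (assumed DDIM formula);\n",
"    # eta=0.0 makes it zero, i.e. the sampler becomes deterministic\n",
"    return eta * np.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_prev))\n",
"\n",
"\n",
"eps_u = torch.zeros(2, 3, 64, 64)  # hypothetical unconditional prediction (latent shape as above)\n",
"eps_c = torch.ones(2, 3, 64, 64)   # hypothetical class-conditional prediction\n",
"print(cfg_combine(eps_u, eps_c, 3.0).mean().item())  # 3.0: guidance extrapolates past eps_c\n",
"\n",
"# hypothetical cumulative alphas for two consecutive sampling steps\n",
"print(ddim_sigma(0.5, 0.6, eta=0.0))  # 0.0\n",
"print(ddim_sigma(0.5, 0.6, eta=1.0) > 0)  # True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},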
{
"cell_type": "code",
"source": [
""
],
"metadata": {
"id": "92QkRfm0e6K0"
},
"execution_count": null,
"outputs": []
}
]
}