stable-diffusion-finetune/examples/prior_2_sd.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0533f618-f54c-4231-b79b-6fd3043696a0",
"metadata": {},
"outputs": [],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig\n",
"import json\n",
"import torch\n",
"import torch\n",
"from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\n",
"from dalle2_pytorch.trainer import DiffusionPriorTrainer\n",
"import clip"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "619dd2aa-4cdb-43bf-b7cd-349826330020",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading model from ../models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\n",
"LatentDiffusion: Running in eps-prediction mode\n",
"DiffusionWrapper has 859.52 M params.\n",
"Keeping EMAs of 688.\n",
"making attention of type 'vanilla' with 512 in_channels\n",
"Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
"making attention of type 'vanilla' with 512 in_channels\n"
]
}
],
"source": [
"import argparse, os, sys, glob\n",
"import torch\n",
"import numpy as np\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from tqdm import tqdm, trange\n",
"from itertools import islice\n",
"from einops import rearrange\n",
"from torchvision.utils import make_grid\n",
"import time\n",
"from pytorch_lightning import seed_everything\n",
"from torch import autocast\n",
"from contextlib import contextmanager, nullcontext\n",
"\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"from scripts.image_variations import load_model_from_config\n",
"\n",
"device = \"cuda:0\"\n",
"\n",
"config = \"../configs/stable-diffusion/sd-image-condition-finetune.yaml\"\n",
"ckpt = \"../models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\"\n",
"config = OmegaConf.load(config)\n",
"model = load_model_from_config(config, ckpt, device)"
]
},
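{
"cell_type": "markdown",
"id": "cond-stage-check-note",
"metadata": {},
"source": [
"This checkpoint is an image-embedding-conditioned fine-tune, so its `cond_stage_config` should point at a CLIP image embedder rather than the usual text encoder. A minimal check, assuming the standard stable-diffusion YAML layout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cond-stage-check",
"metadata": {},
"outputs": [],
"source": [
"# sketch: inspect the conditioning stage of the loaded YAML config\n",
"# (the model.params.cond_stage_config key path follows the usual stable-diffusion configs)\n",
"print(config.model.params.cond_stage_config)"
]
},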
{
"cell_type": "code",
"execution_count": 4,
"id": "4181066b-1641-476f-aaca-ae49e6950dd2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# use 0.15.4\n",
"def load_prior(model_path):\n",
" prior_network = DiffusionPriorNetwork(\n",
" dim=768,\n",
" depth=24,\n",
" dim_head=64,\n",
" heads=32,\n",
" normformer=True,\n",
" attn_dropout=5e-2,\n",
" ff_dropout=5e-2,\n",
" num_time_embeds=1,\n",
" num_image_embeds=1,\n",
" num_text_embeds=1,\n",
" num_timesteps=1000,\n",
" ff_mult=4\n",
" )\n",
"\n",
" diffusion_prior = DiffusionPrior(\n",
" net=prior_network,\n",
" clip=OpenAIClipAdapter(\"ViT-L/14\"),\n",
" image_embed_dim=768,\n",
" timesteps=1000,\n",
" cond_drop_prob=0.1,\n",
" loss_type=\"l2\",\n",
" condition_on_text_encodings=True,\n",
" ).to(device)\n",
"\n",
" state_dict = torch.load(model_path, map_location='cpu')\n",
" if 'ema_model' in state_dict:\n",
" print('Loading EMA Model')\n",
" diffusion_prior.load_state_dict(state_dict['ema_model'], strict=True)\n",
" else:\n",
" print('Loading Standard Model')\n",
" diffusion_prior.load_state_dict(state_dict['model'], strict=False)\n",
" del state_dict\n",
" return diffusion_prior"
]
},
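{
"cell_type": "markdown",
"id": "load-prior-usage-note",
"metadata": {},
"source": [
"`load_prior` rebuilds the prior with hard-coded hyperparameters and loads a raw dalle2-pytorch checkpoint; it is an alternative to the config-driven `make_prior` path used below. A usage sketch (the checkpoint path is only a placeholder):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "load-prior-usage",
"metadata": {},
"outputs": [],
"source": [
"# sketch: placeholder path, adjust to wherever a raw dalle2-pytorch prior checkpoint lives\n",
"# prior = load_prior(\"../../DALLE2-pytorch/pretrained/prior_weights.pth\")"
]
},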
{
"cell_type": "code",
"execution_count": 5,
"id": "ee81a270-2b7d-4d8c-8cf6-031e982597e2",
"metadata": {},
"outputs": [],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig, TrainDiffusionPriorConfig\n",
"\n",
"def make_prior(\n",
" prior_config: DiffusionPriorConfig, checkpoint_path: str, device: str = None\n",
"):\n",
" # create model from config\n",
" diffusion_prior = prior_config.create()\n",
" state_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n",
" diffusion_prior.load_state_dict(state_dict)\n",
" diffusion_prior.eval()\n",
" diffusion_prior.to(device)\n",
"\n",
" if device == \"cpu\":\n",
" diffusion_prior.float()\n",
" return diffusion_prior\n",
"\n",
"# load entire config\n",
"train_config = TrainDiffusionPriorConfig.from_json_path(\"../../DALLE2-pytorch/pretrained/prior_config.json\")\n",
"prior_config = train_config.prior\n",
"\n",
"# load model\n",
"prior = make_prior(prior_config=prior_config, checkpoint_path=\"../../DALLE2-pytorch/pretrained/latest.pth\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "96d74e72-54d8-4529-a7c2-cfe5c0c8008e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3ecb3ec443554c2495af586d9f516dc9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"sampling loop time step: 0%| | 0/64 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# tokenize the text\n",
"tokenized_text = clip.tokenize(\"A watercolour painting of a moutain\").to(device)\n",
"# predict an embedding, make sure to denormalise\n",
"predicted_embedding = prior.sample(tokenized_text, num_samples_per_batch=2, cond_scale=1.0)*prior.image_embed_scale"
]
},
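{
"cell_type": "markdown",
"id": "prior-embedding-check-note",
"metadata": {},
"source": [
"A quick sanity check, assuming the prior returns one 768-dimensional CLIP-style image embedding per prompt (i.e. shape `(1, 768)` here):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "prior-embedding-check",
"metadata": {},
"outputs": [],
"source": [
"# sketch: expected shape is (1, 768) for a single prompt with ViT-L/14 embeddings\n",
"print(predicted_embedding.shape)\n",
"# scale used above to undo the prior's internal embedding normalisation\n",
"print(prior.image_embed_scale)"
]
},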
{
"cell_type": "code",
"execution_count": 20,
"id": "7680379b-01bb-41a8-884d-4a10fa887ce0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Data shape for PLMS sampling is (4, 4, 64, 64)\n",
"Running PLMS Sampling with 50 timesteps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"PLMS Sampler: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:18<00:00, 2.74it/s]\n"
]
}
],
"source": [
"plms = True\n",
"outdir = \"prior2sd\"\n",
"n_samples = 4\n",
"n_rows = 0\n",
"precision = \"fp32\"\n",
"\n",
"ddim_steps = 50\n",
"scale = 3.0\n",
"ddim_eta = 0.0\n",
"\n",
"if plms:\n",
" sampler = PLMSSampler(model)\n",
"else:\n",
" sampler = DDIMSampler(model)\n",
"\n",
"os.makedirs(outdir, exist_ok=True)\n",
"outpath = outdir\n",
"\n",
"batch_size = n_samples\n",
"n_rows = n_rows if n_rows > 0 else batch_size\n",
"\n",
"\n",
"sample_path = os.path.join(outpath, \"samples\")\n",
"os.makedirs(sample_path, exist_ok=True)\n",
"base_count = len(os.listdir(sample_path))\n",
"grid_count = len(os.listdir(outpath)) - 1\n",
"\n",
"start_code = None\n",
"\n",
"# c = torch.rand(n_samples, 1, 768, device=device)\n",
"c = predicted_embedding.tile(n_samples, 1).unsqueeze(1)\n",
"\n",
"precision_scope = autocast if precision==\"autocast\" else nullcontext\n",
"with torch.no_grad():\n",
" with precision_scope(\"cuda\"):\n",
" with model.ema_scope():\n",
" tic = time.time()\n",
" # c = model.get_learned_conditioning(prompts)\n",
"\n",
" uc = None\n",
" if scale != 1.0:\n",
" uc = torch.zeros_like(c)\n",
" shape = [4, 512 // 8, 512 // 8]\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
" conditioning=c,\n",
" batch_size=n_samples,\n",
" shape=shape,\n",
" verbose=False,\n",
" unconditional_guidance_scale=scale,\n",
" unconditional_conditioning=uc,\n",
" eta=ddim_eta,\n",
" x_T=start_code)\n",
"\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
" for x_sample in x_samples_ddim:\n",
" x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
" Image.fromarray(x_sample.astype(np.uint8)).save(\n",
" os.path.join(sample_path, f\"{base_count:05}.png\"))\n",
" base_count += 1\n"
]
},
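{
"cell_type": "markdown",
"id": "grid-save-note",
"metadata": {},
"source": [
"`make_grid`, `n_rows`, and `grid_count` are set up above but only the individual samples are saved; the cell below is a minimal sketch that also tiles the decoded batch into one grid image next to them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "grid-save",
"metadata": {},
"outputs": [],
"source": [
"# sketch: tile the decoded batch (values already in [0, 1]) into a single grid image\n",
"grid = make_grid(x_samples_ddim, nrow=n_rows)\n",
"grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
"Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f\"grid-{grid_count:04}.png\"))\n",
"grid_count += 1"
]
},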
{
"cell_type": "code",
"execution_count": 22,
"id": "caef0607-dc6e-4862-99ed-15281c269a49",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAEAAElEQVR4nHT923psuY40ikWATFWt/j/7/a/9Dr7yW9m7a0o5CMS+AEAyVcvqXrOkzHHgAQwEDgT5//l//7/+85//8/r6a37NYRMgSUUE8IT+rx////2s/+87HhiGAQQIgRBAQJIAkgBQn4D5BdEX4Xyb30sCkBfur8HzVHBfj/2jfI/EunK/VarbVf/Ww0F0I04b8Ltdglivxf2+X/dWxz4aC4C623LfREHXx/12gnmT6q7rqd3gGgL2H7obL0EhhKBASGv5z8/6+YYCEeEuD3eXh0LDjBF61vPnD9Zaf/6sf/53rMfc5Yu+jEZBYv6PhA0SoAijFL4eEb4WCSO+Xq8xbA6bc/p6jA4t9z9j/Ez+/PWXz7kGOQySGyRZaIIvkmPYMA4bAojhgYC9ZT9L6xEQ0/D1ZV8jqBgm+gM4abEciBVY+vqJ//nzvP48fwVfblM2l+x5FiIACRKgCICwwfEKhYQQYIQN0DycpEhgeIRLpMxow0gzWkSQJBmKYZa/K8KIcIcEas5XyAmSI8Xu9deMcBvjr//zn/E17fWFOeff/xn/8zdeX5zG1xfmC2OEEeQWBZhRIqkjoPUtVP/7LZq1ziiJl/C1ZG+J/hB91QL5fNh5639blZ8/zHXA//7tRwM/hP5fzzsray/Yf9+8B4PYQFOQowQhYf+HPTBiI0wixUfHd+/uYeC/21erjtmAuux3t3sg7uX573FQXXnw5l8js1ELe/4K607bGqE+776bf9pHkgIjqBjrx9Y//7H1P3+N/8f/+fs///nP62tQioBNTMCWNAmJoQJohEQLl1YoYOAUI4xgJHKBKbrdsP3uPasE6gpeXc5BJ3oy7sFQi7X6AY2fKdBHz+geSJ7/1fXs8cw2lMTqDFX+zS3m3U5I3K+rEddZVfsxAFBdrqeyb+y35urcE6MzOwKvjv2eTN3z2fpVJcX5JIVau5kS3SK1QsC1niV3CZA8XI/T3X/eWk+83/5+0yOeRcgEBUEpGkYkG2aqeaEgwYCIGIRCAEGRNFLPYjgYNH9hWfy85s+kXkM0GSCFQpIDosYwmsJgkIa9FEAAgGLBQYqIUepoGTWHhYISsMzcIND88Vgun5JEmkwyKF9X74sQARgRiFigURjDAlQgFBIj+0CTBCWqsMZgwGgSJJmNIi1GBDycECiDxXLBSYMRDpDxs2AAY/38hFzv9/j68ogXpPm8/vM/iDUwCUxRxigNAEbyqkSrJk5kjk8LxOZHLSWJY0pd8kE18qP67NxBXmSrOEh+uZfkWTTqVdQtqlaB+zPdkPgBogKs1sdNXm7dcq39XlCff7by2ov6fkk0ovaKbY1JqZaqlOt6E8HSCtWvC7o/Vd7pCD/bmZi2yWtzNe6H/Buea8F0I7QV+8U3iTOByQol2L/V5u9J7iGRRFClSLnBRokadIA0J2UQLIRwGWERiDEjjIC0oIEtGETEytk30oQoGGKLQK4aFPQFdIknmvyqP8v5aE1cwnmPNFvJ9+AS+4NSJTX6yZxq3KoJPTdboI9y7qWztXm/db+rpK2FjdX+D0JUw3JPLo/CaJ5w3Z/qrPqOm4idl37yLF4t2esxF7pa/ehTI0kesfx5+3oinMB6/zzPD0JjDIT0ftbz1vvBWlqP3m+upbVMQckADGNEPTNkZoMpfNnfSO1AgaQNmtk0vgbnYEQICq0RLi3Sh2lOGyYwIEcAHqCRHAODoxEotVVIjBBCpgBJ45g2B0wAFeGAJElBEDTCJDI8/K0Ywggb4Y8EhARPnp5LwkgjXRERtAFiGAOmfCUoGREEKHGMMchS20FaqhCzS2/XtMDMCEUEFDBamhsCRIoUtH4e/wlB/9j4n79XPPOv/wgYf/8PyCHpNcEB9VhvqlD/NEScpXRQrqkpG42Vn4iHZjbxaEmshbI/zCs2FaobtAWzV8ch+3thNxpHgVkL7V7tbYjf2LXXY5w1lTd0n3Yz1OQmra5Ccbb20maJe1U3BQOBUDT5Kj9Ev4E8Ur1V2Faju4/c6/4TqHdHBIDGvvIa1WptvRetHeoZTT91DfiNUa3hubXYrZ3uhpD3VJ3GbnXcPRKtPDEBBEwwQcACZyoNY0waATEICYE03ZMAs4kGDEwStKlAtaQ1qTUt7ZuoWwL277zgW9fIcg/HNiy29H1owVJt/Qi2E4VnGLARegOxDma3OhbL53FfuE2HJin92G1PgHYUCPtxR3XhXkTbEEmx3o/bl55VfvWwRjbV/RZbfjq2BEnuHr5Cy2NBsZ7n/f553j8DJncG/Hn7n2+8H8bicjxPPA89ElZz9KKH25jko6a/GVPRYtLMaAQMohZCDE7G48BjesCgEQgEQShKWeU8DgyjgZF6LRABLnFJT+SitTnMDGYJpmr0d0rkEE0y2ljhgRnE4xHhIiUmr0dLULL7SJsWkhxBkuXkQtlWRgGyYdW3Qq4Qonw0ZpYXhwMiRKOEMjQgAz3HTBIwOP154wknQh4wywHm4PzifEta7hNfg38PmylRtSo2BhyhKF1Q0nnxo1vG2Zh2lp3O7YDthdIIQiVQokljivBev1vOjpHfMn3I18cCP7L7ef3tlkggu5fzvvFc8OnOapWHgs5LZ209ubmaWpWqcaCu3tqzX99rMSGgR3J/kmx1jzva8tZ+3nlMc/gzMVt33t3cbT+DuyGlOn9rhOab3NT5UsZ3a89QHTHJ/xaJUVqXbZDt+TUYLFwzH6Ri2FDxTYJI4nfYwnX/GdCjmM6k9fX6HJE9ArxVW//WfrqWrHrrHtmjL3iYzjFRe9XU6HCDawcMPhpWw7mhepOs8jcel9zxzO5u6hI6HMfU+X4zsP6Se832WDQX+bBOPnTdHUBoQSw5T/QJly+tpYik1svd3+9YCyFwDJDLuR49P3qW1sJaCJeHDRqtVkHy4AJcQBLLehMEyowAjWYsN0uag8MAhY1gvOdXzCFyCXAPUAgFSQwjzEAT7ZIxYSEccEFGQmMazW0YBxiAAnLAgZWkK4QQXcCY4YhwxRCWzAIAjKmPjQiS1os2kRmmpkUkaYgwaz7Cs4Z7fiXJMEqUkM6lAISIemy42bBabVCIIfhSLI/HU5Lm0PPoWVzrZaYVEe9YD4lhL7xmk7sDGO3Sv8jrB5psiT/MdLOVjZgt/ZdLKfUikMoPn2rkGNLXqit688Gu2qxvEfmEoMtltBt7CCxa6OvGDdrEJ8bg+m9rjWTF+yXXat+ciBkT25b0Hp7yabChersLeHu39i+ph3oNopZxBe12XzdiM1dM96jhtxtW46DWLvVG7ccmhF5Duduy+8bPsM3tjmicucVXgBmPjlLqEB05a4Vq5BwJQrSt0GoEjVgkmSvLcu31RTrvq5d+6Nkct5QS3q3cw4d+FlPSeV3QUrl9TTmNB5TP67fGRNm53B2/xZO3qXm//6D29dcZ+vPrURH7P7Yl5Ej2vV5TmZ15aoLyG+3/bedcM4m9ms/ikSLcfb3jeeQLEWZcCl9PRMRaKHSmRbiHl
ms9XE4PAwCYar4JMQ1D2vFntbnRyAFy5kR5BA0RYQaEoGWIMTFfNoxjkHTFUgQ0zIbxxQqsIhWJIIQEMgZCRNTiQfTKBREJhKSJSZGNQQXBL2CCjEBEhKXNbMaRvZG5n1FvPkHChsEWQNBosKPlWkbQQ9IilVSfAGRkMPVAUPlos7YHjQwj5JKHv329XQA4x3+G0SAzKiK0YthbjjHm199mjjm2azTf3oh8ZAqF42pqo0NuUOAGXUuFwPaXnAvRILSlvWe4BawEu33uZylfK+Zup3bQ+vxCHIjYbznB7avlh9ueZ/QTcE2KPt+02837Pyy5ylWJ7vGlMS6mfx6/R/Is9npJM7wDCodfHwipJh17/UKW39hyPfkChq0Be7B3+KKJpu0GfkRMG23Tli0ik076cggCjcsCWAulGpzGgWCcQhCwbqIkWsua9f/fs6dIwN40vCZAwI6dovuAhsLL39NNq27chCH/aT/PHrwChttLfs3W5WAq+f2Yi/z0ygD4dXe3//7PfdU92R8PxSWnl3l
"text/plain": [
"<PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Image.open(\"prior2sd/samples/00000.png\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52ef982f-00e5-4c3a-b7f9-9a2cb51d6dec",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "7b7b6e55edb8d6b4ec26da3e41ac48d31f242b54c90f284dae7273709056fff2"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}