add prior to sd notebook

This commit is contained in:
Justin Pinkney 2022-09-05 03:34:29 -04:00
parent 12dd21670b
commit c7504a6ec8
3 changed files with 311 additions and 2 deletions

2
.gitignore vendored
View File

@ -1,6 +1,6 @@
logs/
dump/
examples/
im-examples/
outputs/
flagged/
*.egg-info

309
examples/prior_2_sd.ipynb Normal file
View File

@ -0,0 +1,309 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0533f618-f54c-4231-b79b-6fd3043696a0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
"/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" ) < LooseVersion(\"1.15\"):\n"
]
}
],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig\n",
"import json\n",
"import torch\n",
"import torch\n",
"from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\n",
"from dalle2_pytorch.trainer import DiffusionPriorTrainer\n",
"import clip"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "619dd2aa-4cdb-43bf-b7cd-349826330020",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading model from /home/jpinkney/code/stable-diffusion/logs/2022-08-19T15-53-26_jp_pretraining_img/checkpoints/last.ckpt\n",
"Global Step: 87056\n",
"LatentDiffusion: Running in eps-prediction mode\n",
"DiffusionWrapper has 859.52 M params.\n",
"Keeping EMAs of 688.\n",
"making attention of type 'vanilla' with 512 in_channels\n",
"Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
"making attention of type 'vanilla' with 512 in_channels\n"
]
}
],
"source": [
"import argparse, os, sys, glob\n",
"import torch\n",
"import numpy as np\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from tqdm import tqdm, trange\n",
"from itertools import islice\n",
"from einops import rearrange\n",
"from torchvision.utils import make_grid\n",
"import time\n",
"from pytorch_lightning import seed_everything\n",
"from torch import autocast\n",
"from contextlib import contextmanager, nullcontext\n",
"\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"from scripts.image_variations import load_model_from_config\n",
"\n",
"device = \"cuda:0\"\n",
"\n",
"config = \"configs/stable-diffusion/sd-image-condition-finetune.yaml\"\n",
"ckpt = \"models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\"\n",
"config = OmegaConf.load(config)\n",
"model = load_model_from_config(config, ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4181066b-1641-476f-aaca-ae49e6950dd2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# use 0.15.4\n",
"def load_prior(model_path):\n",
" prior_network = DiffusionPriorNetwork(\n",
" dim=768,\n",
" depth=24,\n",
" dim_head=64,\n",
" heads=32,\n",
" normformer=True,\n",
" attn_dropout=5e-2,\n",
" ff_dropout=5e-2,\n",
" num_time_embeds=1,\n",
" num_image_embeds=1,\n",
" num_text_embeds=1,\n",
" num_timesteps=1000,\n",
" ff_mult=4\n",
" )\n",
"\n",
" diffusion_prior = DiffusionPrior(\n",
" net=prior_network,\n",
" clip=OpenAIClipAdapter(\"ViT-L/14\"),\n",
" image_embed_dim=768,\n",
" timesteps=1000,\n",
" cond_drop_prob=0.1,\n",
" loss_type=\"l2\",\n",
" condition_on_text_encodings=True,\n",
" ).to(device)\n",
"\n",
" state_dict = torch.load(model_path, map_location='cpu')\n",
" if 'ema_model' in state_dict:\n",
" print('Loading EMA Model')\n",
" diffusion_prior.load_state_dict(state_dict['ema_model'], strict=True)\n",
" else:\n",
" print('Loading Standard Model')\n",
" diffusion_prior.load_state_dict(state_dict['model'], strict=False)\n",
" del state_dict\n",
" return diffusion_prior"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ee81a270-2b7d-4d8c-8cf6-031e982597e2",
"metadata": {},
"outputs": [],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig, TrainDiffusionPriorConfig\n",
"\n",
"def make_prior(\n",
" prior_config: DiffusionPriorConfig, checkpoint_path: str, device: str = None\n",
"):\n",
" # create model from config\n",
" diffusion_prior = prior_config.create()\n",
" state_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n",
" diffusion_prior.load_state_dict(state_dict)\n",
" diffusion_prior.eval()\n",
" diffusion_prior.to(device)\n",
"\n",
" if device == \"cpu\":\n",
" diffusion_prior.float()\n",
" return diffusion_prior\n",
"\n",
"# load entire config\n",
"train_config = TrainDiffusionPriorConfig.from_json_path(\"../DALLE2-pytorch/pretrained/prior_config.json\")\n",
"prior_config = train_config.prior\n",
"\n",
"# load model\n",
"prior = make_prior(prior_config=prior_config, checkpoint_path=\"../DALLE2-pytorch/pretrained/latest.pth\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "96d74e72-54d8-4529-a7c2-cfe5c0c8008e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e83104e232744f9952587e16c034668",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"sampling loop time step: 0%| | 0/64 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# tokenize the text\n",
"tokenized_text = clip.tokenize(\"still life in the style of Picasso\").to(device)\n",
"# predict an embedding, make sure to denormalise\n",
"predicted_embedding = prior.sample(tokenized_text, num_samples_per_batch=2, cond_scale=1.0)*prior.image_embed_scale"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "7680379b-01bb-41a8-884d-4a10fa887ce0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------\n",
"[4, 64, 64]\n",
"Data shape for DDIM sampling is (4, 4, 64, 64), eta 1.0\n",
"Running DDIM Sampling with 50 timesteps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DDIM Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00, 2.79it/s]\n"
]
}
],
"source": [
"plms = False\n",
"outdir = \"dump\"\n",
"n_samples = 4\n",
"n_rows = 0\n",
"precision = \"fp32\"\n",
"\n",
"ddim_steps = 50\n",
"scale = 3.0\n",
"ddim_eta = 1.0\n",
"\n",
"if plms:\n",
" sampler = PLMSSampler(model)\n",
"else:\n",
" sampler = DDIMSampler(model)\n",
"\n",
"os.makedirs(outdir, exist_ok=True)\n",
"outpath = outdir\n",
"\n",
"batch_size = n_samples\n",
"n_rows = n_rows if n_rows > 0 else batch_size\n",
"\n",
"\n",
"sample_path = os.path.join(outpath, \"samples\")\n",
"os.makedirs(sample_path, exist_ok=True)\n",
"base_count = len(os.listdir(sample_path))\n",
"grid_count = len(os.listdir(outpath)) - 1\n",
"\n",
"start_code = None\n",
"\n",
"# c = torch.rand(n_samples, 1, 768, device=device)\n",
"c = predicted_embedding.tile(n_samples, 1).unsqueeze(1)\n",
"\n",
"precision_scope = autocast if precision==\"autocast\" else nullcontext\n",
"with torch.no_grad():\n",
" with precision_scope(\"cuda\"):\n",
" with model.ema_scope():\n",
" tic = time.time()\n",
" # c = model.get_learned_conditioning(prompts)\n",
"\n",
" uc = None\n",
" if scale != 1.0:\n",
" uc = torch.zeros_like(c)\n",
" shape = [4, 512 // 8, 512 // 8]\n",
" print(\"--------\")\n",
" print(shape)\n",
" samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
" conditioning=c,\n",
" batch_size=n_samples,\n",
" shape=shape,\n",
" verbose=False,\n",
" unconditional_guidance_scale=scale,\n",
" unconditional_conditioning=uc,\n",
" eta=ddim_eta,\n",
" x_T=start_code)\n",
"\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
" for x_sample in x_samples_ddim:\n",
" x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
" Image.fromarray(x_sample.astype(np.uint8)).save(\n",
" os.path.join(sample_path, f\"{base_count:05}.png\"))\n",
" base_count += 1\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caef0607-dc6e-4862-99ed-15281c269a49",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('stable')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "7b7b6e55edb8d6b4ec26da3e41ac48d31f242b54c90f284dae7273709056fff2"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -19,7 +19,7 @@ from ldm.util import instantiate_from_config
def load_model_from_config(config, ckpt, device, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
pl_sd = torch.load(ckpt, map_location=device)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]