add prior to sd notebook
parent 12dd21670b
commit c7504a6ec8
3 changed files with 311 additions and 2 deletions
.gitignore (vendored) +1 −1
@@ -1,6 +1,6 @@
 logs/
 dump/
-examples/
+im-examples/
 outputs/
 flagged/
 *.egg-info
examples/prior_2_sd.ipynb (new file) +309
@@ -0,0 +1,309 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0533f618-f54c-4231-b79b-6fd3043696a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
      "  if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
      "/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
      "  ) < LooseVersion(\"1.15\"):\n"
     ]
    }
   ],
   "source": [
    "from dalle2_pytorch.train_configs import DiffusionPriorConfig\n",
    "import json\n",
    "import torch\n",
    "from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\n",
    "from dalle2_pytorch.trainer import DiffusionPriorTrainer\n",
    "import clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "619dd2aa-4cdb-43bf-b7cd-349826330020",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model from /home/jpinkney/code/stable-diffusion/logs/2022-08-19T15-53-26_jp_pretraining_img/checkpoints/last.ckpt\n",
      "Global Step: 87056\n",
      "LatentDiffusion: Running in eps-prediction mode\n",
      "DiffusionWrapper has 859.52 M params.\n",
      "Keeping EMAs of 688.\n",
      "making attention of type 'vanilla' with 512 in_channels\n",
      "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
      "making attention of type 'vanilla' with 512 in_channels\n"
     ]
    }
   ],
   "source": [
    "import argparse, os, sys, glob\n",
    "import torch\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "from PIL import Image\n",
    "from tqdm import tqdm, trange\n",
    "from itertools import islice\n",
    "from einops import rearrange\n",
    "from torchvision.utils import make_grid\n",
    "import time\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import autocast\n",
    "from contextlib import contextmanager, nullcontext\n",
    "\n",
    "from ldm.util import instantiate_from_config\n",
    "from ldm.models.diffusion.ddim import DDIMSampler\n",
    "from ldm.models.diffusion.plms import PLMSSampler\n",
    "from scripts.image_variations import load_model_from_config\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "config = \"configs/stable-diffusion/sd-image-condition-finetune.yaml\"\n",
    "ckpt = \"models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\"\n",
    "config = OmegaConf.load(config)\n",
    "model = load_model_from_config(config, ckpt, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4181066b-1641-476f-aaca-ae49e6950dd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# requires dalle2-pytorch 0.15.4\n",
    "def load_prior(model_path):\n",
    "    prior_network = DiffusionPriorNetwork(\n",
    "        dim=768,\n",
    "        depth=24,\n",
    "        dim_head=64,\n",
    "        heads=32,\n",
    "        normformer=True,\n",
    "        attn_dropout=5e-2,\n",
    "        ff_dropout=5e-2,\n",
    "        num_time_embeds=1,\n",
    "        num_image_embeds=1,\n",
    "        num_text_embeds=1,\n",
    "        num_timesteps=1000,\n",
    "        ff_mult=4\n",
    "    )\n",
    "\n",
    "    diffusion_prior = DiffusionPrior(\n",
    "        net=prior_network,\n",
    "        clip=OpenAIClipAdapter(\"ViT-L/14\"),\n",
    "        image_embed_dim=768,\n",
    "        timesteps=1000,\n",
    "        cond_drop_prob=0.1,\n",
    "        loss_type=\"l2\",\n",
    "        condition_on_text_encodings=True,\n",
    "    ).to(device)\n",
    "\n",
    "    state_dict = torch.load(model_path, map_location='cpu')\n",
    "    if 'ema_model' in state_dict:\n",
    "        print('Loading EMA Model')\n",
    "        diffusion_prior.load_state_dict(state_dict['ema_model'], strict=True)\n",
    "    else:\n",
    "        print('Loading Standard Model')\n",
    "        diffusion_prior.load_state_dict(state_dict['model'], strict=False)\n",
    "    del state_dict\n",
    "    return diffusion_prior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee81a270-2b7d-4d8c-8cf6-031e982597e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dalle2_pytorch.train_configs import DiffusionPriorConfig, TrainDiffusionPriorConfig\n",
    "\n",
    "def make_prior(\n",
    "    prior_config: DiffusionPriorConfig, checkpoint_path: str, device: str = None\n",
    "):\n",
    "    # create model from config\n",
    "    diffusion_prior = prior_config.create()\n",
    "    state_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n",
    "    diffusion_prior.load_state_dict(state_dict)\n",
    "    diffusion_prior.eval()\n",
    "    diffusion_prior.to(device)\n",
    "\n",
    "    if device == \"cpu\":\n",
    "        diffusion_prior.float()\n",
    "    return diffusion_prior\n",
    "\n",
    "# load entire config\n",
    "train_config = TrainDiffusionPriorConfig.from_json_path(\"../DALLE2-pytorch/pretrained/prior_config.json\")\n",
    "prior_config = train_config.prior\n",
    "\n",
    "# load model\n",
    "prior = make_prior(prior_config=prior_config, checkpoint_path=\"../DALLE2-pytorch/pretrained/latest.pth\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "96d74e72-54d8-4529-a7c2-cfe5c0c8008e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e83104e232744f9952587e16c034668",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# tokenize the text\n",
    "tokenized_text = clip.tokenize(\"still life in the style of Picasso\").to(device)\n",
    "# predict an embedding, make sure to denormalise\n",
    "predicted_embedding = prior.sample(tokenized_text, num_samples_per_batch=2, cond_scale=1.0) * prior.image_embed_scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "7680379b-01bb-41a8-884d-4a10fa887ce0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------\n",
      "[4, 64, 64]\n",
      "Data shape for DDIM sampling is (4, 4, 64, 64), eta 1.0\n",
      "Running DDIM Sampling with 50 timesteps\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DDIM Sampler: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]\n"
     ]
    }
   ],
   "source": [
    "plms = False\n",
    "outdir = \"dump\"\n",
    "n_samples = 4\n",
    "n_rows = 0\n",
    "precision = \"fp32\"\n",
    "\n",
    "ddim_steps = 50\n",
    "scale = 3.0\n",
    "ddim_eta = 1.0\n",
    "\n",
    "if plms:\n",
    "    sampler = PLMSSampler(model)\n",
    "else:\n",
    "    sampler = DDIMSampler(model)\n",
    "\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "outpath = outdir\n",
    "\n",
    "batch_size = n_samples\n",
    "n_rows = n_rows if n_rows > 0 else batch_size\n",
    "\n",
    "sample_path = os.path.join(outpath, \"samples\")\n",
    "os.makedirs(sample_path, exist_ok=True)\n",
    "base_count = len(os.listdir(sample_path))\n",
    "grid_count = len(os.listdir(outpath)) - 1\n",
    "\n",
    "start_code = None\n",
    "\n",
    "# c = torch.rand(n_samples, 1, 768, device=device)\n",
    "c = predicted_embedding.tile(n_samples, 1).unsqueeze(1)\n",
    "\n",
    "precision_scope = autocast if precision == \"autocast\" else nullcontext\n",
    "with torch.no_grad():\n",
    "    with precision_scope(\"cuda\"):\n",
    "        with model.ema_scope():\n",
    "            tic = time.time()\n",
    "            # c = model.get_learned_conditioning(prompts)\n",
    "\n",
    "            uc = None\n",
    "            if scale != 1.0:\n",
    "                uc = torch.zeros_like(c)\n",
    "            shape = [4, 512 // 8, 512 // 8]\n",
    "            print(\"--------\")\n",
    "            print(shape)\n",
    "            samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
    "                                             conditioning=c,\n",
    "                                             batch_size=n_samples,\n",
    "                                             shape=shape,\n",
    "                                             verbose=False,\n",
    "                                             unconditional_guidance_scale=scale,\n",
    "                                             unconditional_conditioning=uc,\n",
    "                                             eta=ddim_eta,\n",
    "                                             x_T=start_code)\n",
    "\n",
    "            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
    "            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
    "\n",
    "            for x_sample in x_samples_ddim:\n",
    "                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
    "                Image.fromarray(x_sample.astype(np.uint8)).save(\n",
    "                    os.path.join(sample_path, f\"{base_count:05}.png\"))\n",
    "                base_count += 1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caef0607-dc6e-4862-99ed-15281c269a49",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('stable')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "7b7b6e55edb8d6b4ec26da3e41ac48d31f242b54c90f284dae7273709056fff2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
scripts/image_variations.py +1 −1
@@ -19,7 +19,7 @@ from ldm.util import instantiate_from_config
 
 def load_model_from_config(config, ckpt, device, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt, map_location=device)
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
     sd = pl_sd["state_dict"]