add prior to sd notebook
parent 12dd21670b
commit c7504a6ec8
3 changed files with 311 additions and 2 deletions
2
.gitignore
vendored
@@ -1,6 +1,6 @@
 logs/
 dump/
-examples/
+im-examples/
 outputs/
 flagged/
 *.egg-info
309
examples/prior_2_sd.ipynb
Normal file
@@ -0,0 +1,309 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "0533f618-f54c-4231-b79b-6fd3043696a0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
"/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
" ) < LooseVersion(\"1.15\"):\n"
]
}
],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig\n",
"import json\n",
"import torch\n",
"from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\n",
"from dalle2_pytorch.trainer import DiffusionPriorTrainer\n",
"import clip"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "619dd2aa-4cdb-43bf-b7cd-349826330020",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading model from /home/jpinkney/code/stable-diffusion/logs/2022-08-19T15-53-26_jp_pretraining_img/checkpoints/last.ckpt\n",
"Global Step: 87056\n",
"LatentDiffusion: Running in eps-prediction mode\n",
"DiffusionWrapper has 859.52 M params.\n",
"Keeping EMAs of 688.\n",
"making attention of type 'vanilla' with 512 in_channels\n",
"Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
"making attention of type 'vanilla' with 512 in_channels\n"
]
}
],
"source": [
"import argparse, os, sys, glob\n",
"import torch\n",
"import numpy as np\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from tqdm import tqdm, trange\n",
"from itertools import islice\n",
"from einops import rearrange\n",
"from torchvision.utils import make_grid\n",
"import time\n",
"from pytorch_lightning import seed_everything\n",
"from torch import autocast\n",
"from contextlib import contextmanager, nullcontext\n",
"\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"from scripts.image_variations import load_model_from_config\n",
"\n",
"device = \"cuda:0\"\n",
"\n",
"config = \"configs/stable-diffusion/sd-image-condition-finetune.yaml\"\n",
"ckpt = \"models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\"\n",
"config = OmegaConf.load(config)\n",
"model = load_model_from_config(config, ckpt, device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4181066b-1641-476f-aaca-ae49e6950dd2",
"metadata": {},
"outputs": [],
"source": [
"\n",
"# use 0.15.4\n",
"def load_prior(model_path):\n",
"    prior_network = DiffusionPriorNetwork(\n",
"        dim=768,\n",
"        depth=24,\n",
"        dim_head=64,\n",
"        heads=32,\n",
"        normformer=True,\n",
"        attn_dropout=5e-2,\n",
"        ff_dropout=5e-2,\n",
"        num_time_embeds=1,\n",
"        num_image_embeds=1,\n",
"        num_text_embeds=1,\n",
"        num_timesteps=1000,\n",
"        ff_mult=4\n",
"    )\n",
"\n",
"    diffusion_prior = DiffusionPrior(\n",
"        net=prior_network,\n",
"        clip=OpenAIClipAdapter(\"ViT-L/14\"),\n",
"        image_embed_dim=768,\n",
"        timesteps=1000,\n",
"        cond_drop_prob=0.1,\n",
"        loss_type=\"l2\",\n",
"        condition_on_text_encodings=True,\n",
"    ).to(device)\n",
"\n",
"    state_dict = torch.load(model_path, map_location='cpu')\n",
"    if 'ema_model' in state_dict:\n",
"        print('Loading EMA Model')\n",
"        diffusion_prior.load_state_dict(state_dict['ema_model'], strict=True)\n",
"    else:\n",
"        print('Loading Standard Model')\n",
"        diffusion_prior.load_state_dict(state_dict['model'], strict=False)\n",
"    del state_dict\n",
"    return diffusion_prior"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ee81a270-2b7d-4d8c-8cf6-031e982597e2",
"metadata": {},
"outputs": [],
"source": [
"from dalle2_pytorch.train_configs import DiffusionPriorConfig, TrainDiffusionPriorConfig\n",
"\n",
"def make_prior(\n",
"    prior_config: DiffusionPriorConfig, checkpoint_path: str, device: str = None\n",
"):\n",
"    # create model from config\n",
"    diffusion_prior = prior_config.create()\n",
"    state_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n",
"    diffusion_prior.load_state_dict(state_dict)\n",
"    diffusion_prior.eval()\n",
"    diffusion_prior.to(device)\n",
"\n",
"    if device == \"cpu\":\n",
"        diffusion_prior.float()\n",
"    return diffusion_prior\n",
"\n",
"# load entire config\n",
"train_config = TrainDiffusionPriorConfig.from_json_path(\"../DALLE2-pytorch/pretrained/prior_config.json\")\n",
"prior_config = train_config.prior\n",
"\n",
"# load model\n",
"prior = make_prior(prior_config=prior_config, checkpoint_path=\"../DALLE2-pytorch/pretrained/latest.pth\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "96d74e72-54d8-4529-a7c2-cfe5c0c8008e",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8e83104e232744f9952587e16c034668",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# tokenize the text\n",
"tokenized_text = clip.tokenize(\"still life in the style of Picasso\").to(device)\n",
"# predict an embedding, make sure to denormalise\n",
"predicted_embedding = prior.sample(tokenized_text, num_samples_per_batch=2, cond_scale=1.0)*prior.image_embed_scale"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "7680379b-01bb-41a8-884d-4a10fa887ce0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------\n",
"[4, 64, 64]\n",
"Data shape for DDIM sampling is (4, 4, 64, 64), eta 1.0\n",
"Running DDIM Sampling with 50 timesteps\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"DDIM Sampler: 100%|████████████████████████████████████████████████████████████| 50/50 [00:17<00:00,  2.79it/s]\n"
]
}
],
"source": [
"plms = False\n",
"outdir = \"dump\"\n",
"n_samples = 4\n",
"n_rows = 0\n",
"precision = \"fp32\"\n",
"\n",
"ddim_steps = 50\n",
"scale = 3.0\n",
"ddim_eta = 1.0\n",
"\n",
"if plms:\n",
"    sampler = PLMSSampler(model)\n",
"else:\n",
"    sampler = DDIMSampler(model)\n",
"\n",
"os.makedirs(outdir, exist_ok=True)\n",
"outpath = outdir\n",
"\n",
"batch_size = n_samples\n",
"n_rows = n_rows if n_rows > 0 else batch_size\n",
"\n",
"\n",
"sample_path = os.path.join(outpath, \"samples\")\n",
"os.makedirs(sample_path, exist_ok=True)\n",
"base_count = len(os.listdir(sample_path))\n",
"grid_count = len(os.listdir(outpath)) - 1\n",
"\n",
"start_code = None\n",
"\n",
"# c = torch.rand(n_samples, 1, 768, device=device)\n",
"c = predicted_embedding.tile(n_samples, 1).unsqueeze(1)\n",
"\n",
"precision_scope = autocast if precision==\"autocast\" else nullcontext\n",
"with torch.no_grad():\n",
"    with precision_scope(\"cuda\"):\n",
"        with model.ema_scope():\n",
"            tic = time.time()\n",
"            # c = model.get_learned_conditioning(prompts)\n",
"\n",
"            uc = None\n",
"            if scale != 1.0:\n",
"                uc = torch.zeros_like(c)\n",
"            shape = [4, 512 // 8, 512 // 8]\n",
"            print(\"--------\")\n",
"            print(shape)\n",
"            samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
"                                             conditioning=c,\n",
"                                             batch_size=n_samples,\n",
"                                             shape=shape,\n",
"                                             verbose=False,\n",
"                                             unconditional_guidance_scale=scale,\n",
"                                             unconditional_conditioning=uc,\n",
"                                             eta=ddim_eta,\n",
"                                             x_T=start_code)\n",
"\n",
"            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
"            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
"            for x_sample in x_samples_ddim:\n",
"                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
"                Image.fromarray(x_sample.astype(np.uint8)).save(\n",
"                    os.path.join(sample_path, f\"{base_count:05}.png\"))\n",
"                base_count += 1\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caef0607-dc6e-4862-99ed-15281c269a49",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 ('stable')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"vscode": {
"interpreter": {
"hash": "7b7b6e55edb8d6b4ec26da3e41ac48d31f242b54c90f284dae7273709056fff2"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@@ -19,7 +19,7 @@ from ldm.util import instantiate_from_config
 
 def load_model_from_config(config, ckpt, device, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt, map_location=device)
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
     sd = pl_sd["state_dict"]
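For reference, a minimal end-to-end sketch of the flow the new notebook implements: sample a CLIP image embedding from text with the diffusion prior, denormalise it with prior.image_embed_scale, and feed it as conditioning to the image-embedding-conditioned Stable Diffusion model. It assumes the prior, model, sampler, and device objects created in the notebook above; the wrapper function name and its defaults are illustrative and are not part of the commit.

import os
import clip
import numpy as np
import torch
from einops import rearrange
from PIL import Image

def text_to_image_via_prior(prompt, prior, model, sampler, device,
                            n_samples=4, ddim_steps=50, scale=3.0, ddim_eta=1.0,
                            outdir="dump/samples"):
    # Hypothetical helper that mirrors the notebook cells above.
    os.makedirs(outdir, exist_ok=True)
    # 1. Text -> CLIP image embedding via the diffusion prior (denormalised by image_embed_scale).
    tokens = clip.tokenize(prompt).to(device)
    image_embed = prior.sample(tokens, num_samples_per_batch=2, cond_scale=1.0) * prior.image_embed_scale
    # 2. Tile the predicted embedding into the conditioning expected by the image-conditioned SD model.
    c = image_embed.tile(n_samples, 1).unsqueeze(1)      # (n_samples, 1, 768)
    uc = torch.zeros_like(c) if scale != 1.0 else None   # unconditional branch for classifier-free guidance
    with torch.no_grad(), model.ema_scope():
        samples, _ = sampler.sample(S=ddim_steps, conditioning=c, batch_size=n_samples,
                                    shape=[4, 512 // 8, 512 // 8], verbose=False,
                                    unconditional_guidance_scale=scale,
                                    unconditional_conditioning=uc, eta=ddim_eta, x_T=None)
        x = model.decode_first_stage(samples)
    x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
    # 3. Save each decoded sample as a PNG, as the notebook does.
    paths = []
    for i, sample in enumerate(x):
        arr = 255. * rearrange(sample.cpu().numpy(), 'c h w -> h w c')
        path = os.path.join(outdir, f"{i:05}.png")
        Image.fromarray(arr.astype(np.uint8)).save(path)
        paths.append(path)
    return paths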