diff --git a/.gitignore b/.gitignore
index 5af5443..c36f75c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,6 @@
 logs/
 dump/
-examples/
+im-examples/
 outputs/
 flagged/
 *.egg-info
diff --git a/examples/prior_2_sd.ipynb b/examples/prior_2_sd.ipynb
new file mode 100644
index 0000000..15f4d9f
--- /dev/null
+++ b/examples/prior_2_sd.ipynb
@@ -0,0 +1,309 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "0533f618-f54c-4231-b79b-6fd3043696a0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
+      "  if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
+      "/home/jpinkney/miniconda3/envs/stable/lib/python3.10/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
+      "  ) < LooseVersion(\"1.15\"):\n"
+     ]
+    }
+   ],
+   "source": [
+    "from dalle2_pytorch.train_configs import DiffusionPriorConfig\n",
+    "import json\n",
+    "import torch\n",
+    "import torch\n",
+    "from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\n",
+    "from dalle2_pytorch.trainer import DiffusionPriorTrainer\n",
+    "import clip"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "619dd2aa-4cdb-43bf-b7cd-349826330020",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading model from /home/jpinkney/code/stable-diffusion/logs/2022-08-19T15-53-26_jp_pretraining_img/checkpoints/last.ckpt\n",
+      "Global Step: 87056\n",
+      "LatentDiffusion: Running in eps-prediction mode\n",
+      "DiffusionWrapper has 859.52 M params.\n",
+      "Keeping EMAs of 688.\n",
+      "making attention of type 'vanilla' with 512 in_channels\n",
+      "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
+      "making attention of type 'vanilla' with 512 in_channels\n"
+     ]
+    }
+   ],
+   "source": [
+    "import argparse, os, sys, glob\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "from omegaconf import OmegaConf\n",
+    "from PIL import Image\n",
+    "from tqdm import tqdm, trange\n",
+    "from itertools import islice\n",
+    "from einops import rearrange\n",
+    "from torchvision.utils import make_grid\n",
+    "import time\n",
+    "from pytorch_lightning import seed_everything\n",
+    "from torch import autocast\n",
+    "from contextlib import contextmanager, nullcontext\n",
+    "\n",
+    "from ldm.util import instantiate_from_config\n",
+    "from ldm.models.diffusion.ddim import DDIMSampler\n",
+    "from ldm.models.diffusion.plms import PLMSSampler\n",
+    "from scripts.image_variations import load_model_from_config\n",
+    "\n",
+    "device = \"cuda:0\"\n",
+    "\n",
+    "config = \"configs/stable-diffusion/sd-image-condition-finetune.yaml\"\n",
+    "ckpt = \"models/ldm/stable-diffusion-v1/sd-clip-vit-l14-img-embed_ema_only.ckpt\"\n",
+    "config = OmegaConf.load(config)\n",
+    "model = load_model_from_config(config, ckpt, device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4181066b-1641-476f-aaca-ae49e6950dd2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# use 0.15.4\n",
+    "def load_prior(model_path):\n",
+    "    prior_network = DiffusionPriorNetwork(\n",
+    "        dim=768,\n",
+    "        depth=24,\n",
+    "        dim_head=64,\n",
+    "        heads=32,\n",
+    "        normformer=True,\n",
+    "        attn_dropout=5e-2,\n",
+    "        ff_dropout=5e-2,\n",
+    "        num_time_embeds=1,\n",
+    "        num_image_embeds=1,\n",
+    "        num_text_embeds=1,\n",
+    "        num_timesteps=1000,\n",
+    "        ff_mult=4\n",
+    "    )\n",
+    "\n",
+    "    diffusion_prior = DiffusionPrior(\n",
+    "        net=prior_network,\n",
+    "        clip=OpenAIClipAdapter(\"ViT-L/14\"),\n",
+    "        image_embed_dim=768,\n",
+    "        timesteps=1000,\n",
+    "        cond_drop_prob=0.1,\n",
+    "        loss_type=\"l2\",\n",
+    "        condition_on_text_encodings=True,\n",
+    "    ).to(device)\n",
+    "\n",
+    "    state_dict = torch.load(model_path, map_location='cpu')\n",
+    "    if 'ema_model' in state_dict:\n",
+    "        print('Loading EMA Model')\n",
+    "        diffusion_prior.load_state_dict(state_dict['ema_model'], strict=True)\n",
+    "    else:\n",
+    "        print('Loading Standard Model')\n",
+    "        diffusion_prior.load_state_dict(state_dict['model'], strict=False)\n",
+    "    del state_dict\n",
+    "    return diffusion_prior"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ee81a270-2b7d-4d8c-8cf6-031e982597e2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dalle2_pytorch.train_configs import DiffusionPriorConfig, TrainDiffusionPriorConfig\n",
+    "\n",
+    "def make_prior(\n",
+    "    prior_config: DiffusionPriorConfig, checkpoint_path: str, device: str = None\n",
+    "):\n",
+    "    # create model from config\n",
+    "    diffusion_prior = prior_config.create()\n",
+    "    state_dict = torch.load(checkpoint_path, map_location=\"cpu\")\n",
+    "    diffusion_prior.load_state_dict(state_dict)\n",
+    "    diffusion_prior.eval()\n",
+    "    diffusion_prior.to(device)\n",
+    "\n",
+    "    if device == \"cpu\":\n",
+    "        diffusion_prior.float()\n",
+    "    return diffusion_prior\n",
+    "\n",
+    "# load entire config\n",
+    "train_config = TrainDiffusionPriorConfig.from_json_path(\"../DALLE2-pytorch/pretrained/prior_config.json\")\n",
+    "prior_config = train_config.prior\n",
+    "\n",
+    "# load model\n",
+    "prior = make_prior(prior_config=prior_config, checkpoint_path=\"../DALLE2-pytorch/pretrained/latest.pth\", device=device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "96d74e72-54d8-4529-a7c2-cfe5c0c8008e",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8e83104e232744f9952587e16c034668",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "sampling loop time step:   0%|          | 0/64 [00:00<?, ?it/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    " 0 else batch_size\n",
+    "\n",
+    "\n",
+    "sample_path = os.path.join(outpath, \"samples\")\n",
+    "os.makedirs(sample_path, exist_ok=True)\n",
+    "base_count = len(os.listdir(sample_path))\n",
+    "grid_count = len(os.listdir(outpath)) - 1\n",
+    "\n",
+    "start_code = None\n",
+    "\n",
+    "# c = torch.rand(n_samples, 1, 768, device=device)\n",
+    "c = predicted_embedding.tile(n_samples, 1).unsqueeze(1)\n",
+    "\n",
+    "precision_scope = autocast if precision==\"autocast\" else nullcontext\n",
+    "with torch.no_grad():\n",
+    "    with precision_scope(\"cuda\"):\n",
+    "        with model.ema_scope():\n",
+    "            tic = time.time()\n",
+    "            # c = model.get_learned_conditioning(prompts)\n",
+    "\n",
+    "            uc = None\n",
+    "            if scale != 1.0:\n",
+    "                uc = torch.zeros_like(c)\n",
+    "            shape = [4, 512 // 8, 512 // 8]\n",
+    "            print(\"--------\")\n",
+    "            print(shape)\n",
+    "            samples_ddim, _ = sampler.sample(S=ddim_steps,\n",
+    "                                             conditioning=c,\n",
+    "                                             batch_size=n_samples,\n",
+    "                                             shape=shape,\n",
+    "                                             verbose=False,\n",
+    "                                             unconditional_guidance_scale=scale,\n",
+    "                                             unconditional_conditioning=uc,\n",
+    "                                             eta=ddim_eta,\n",
+    "                                             x_T=start_code)\n",
+    "\n",
+    "            x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
+    "            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
+    "\n",
+    "            for x_sample in x_samples_ddim:\n",
+    "                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n",
+    "                Image.fromarray(x_sample.astype(np.uint8)).save(\n",
+    "                    os.path.join(sample_path, f\"{base_count:05}.png\"))\n",
+    "                base_count += 1\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "caef0607-dc6e-4862-99ed-15281c269a49",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.10.4 ('stable')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.4"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "7b7b6e55edb8d6b4ec26da3e41ac48d31f242b54c90f284dae7273709056fff2"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/scripts/image_variations.py b/scripts/image_variations.py
index 16959a0..28bcee4 100644
--- a/scripts/image_variations.py
+++ b/scripts/image_variations.py
@@ -19,7 +19,7 @@ from ldm.util import instantiate_from_config
 
 def load_model_from_config(config, ckpt, device, verbose=False):
     print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
+    pl_sd = torch.load(ckpt, map_location=device)
     if "global_step" in pl_sd:
         print(f"Global Step: {pl_sd['global_step']}")
     sd = pl_sd["state_dict"]