Fine-Tune Stable Diffusion

This notebook fine-tunes Stable Diffusion on Pokemon images. For more details, see the Lambda Labs examples repo.

We recommend using a multi-GPU machine, for example an instance from Lambda GPU Cloud. If you run on Colab, this notebook will likely need a GPU with more than 16GB of VRAM and a high-RAM runtime, which almost certainly requires Colab Pro or Pro+. If you see errors such as "Killed" (system RAM ran out) or "CUDA out of memory" (GPU memory ran out), the corresponding resource is not sufficient.
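Before installing anything, it can save time to confirm how much memory each visible GPU actually has. The sketch below assumes PyTorch is available; the helper names are hypothetical, and the 16 GB threshold is taken from the note above. It requires every GPU to clear the threshold, on the assumption that multi-GPU data-parallel training keeps a full model copy on each device. PyTorch reports bytes, which are converted to GiB here, so treat the threshold as approximate.

```python
def gpu_memory_gb():
    """Return total memory (GiB) of each visible CUDA device; [] if torch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return []
    if not torch.cuda.is_available():
        return []
    return [
        torch.cuda.get_device_properties(i).total_memory / 1024**3
        for i in range(torch.cuda.device_count())
    ]


def meets_vram_requirement(mems_gb, min_gb=16.0):
    """True only if at least one GPU is visible and every GPU has more than min_gb."""
    return bool(mems_gb) and all(m > min_gb for m in mems_gb)


if __name__ == "__main__":
    mems = gpu_memory_gb()
    print("GPU memory (GiB):", [round(m, 1) for m in mems])
    if not meets_vram_requirement(mems):
        print("Warning: no GPU with more than 16 GB found; expect 'CUDA out of memory'.")
```

On the A100 40GB shown by nvidia-smi later in this notebook, this would report roughly 40 GiB and pass the check.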

In [2]:
!git clone https://github.com/justinpinkney/stable-diffusion.git
%cd stable-diffusion
!pip install --upgrade pip
!pip install -r requirements.txt

!pip install --upgrade keras # on lambda stack we need to upgrade keras
#!pip uninstall -y torchtext # on colab we need to remove torchtext
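Because requirements.txt pins several packages (for example pytorch-lightning==1.4.2, omegaconf==2.1.1 and diffusers==0.3.0), a later install can silently replace them. A minimal sketch for checking the pins after installation, using only the standard library's importlib.metadata (Python 3.8+); the function name and the choice of packages to check are illustrative assumptions, not part of the repo.

```python
from importlib import metadata


def check_versions(expected):
    """Compare installed distribution versions against expected pins.

    Returns {package: (expected, installed_or_None)} for every mismatch.
    """
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            have = None  # package is not installed at all
        if have != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    pins = {"pytorch-lightning": "1.4.2", "omegaconf": "2.1.1", "diffusers": "0.3.0"}
    print(check_versions(pins) or "All pinned versions installed.")
```

An empty result means the environment matches the pins; anything else is worth fixing before training, since the training code targets these versions.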
In [3]:
!nvidia-smi
Fri Oct  7 07:47:59 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:06:00.0 Off |                    0 |
| N/A   33C    P0    46W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
In [8]:
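# Check the dataset.
# Original Pokemon dataset check, kept for reference:
# from datasets import load_dataset
# ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
# sample = ds[0]
# display(sample["image"].resize((256, 256)))
# print(sample["text"])

# Custom dataset from the local paris_dataloader module.
from paris_dataloader import ParisDataset

ds = ParisDataset("../VLoD/", 512, 512, None, [])

# NOTE: indexing is disabled because ParisDataset does not override __getitem__,
# so ds[0] raises NotImplementedError (traceback below is from a run with this line enabled).
# sample = ds[0]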
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-8-df3cd5287312> in <module>
      9 ds = ParisDataset("../VLoD/", 512, 512, None, [])
     10 
---> 11 sample = ds[0]

/usr/lib/python3/dist-packages/torch/utils/data/dataset.py in __getitem__(self, index)
     66 
     67     def __getitem__(self, index) -> T_co:
---> 68         raise NotImplementedError
     69 
     70     def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':

NotImplementedError: 
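The traceback above shows ds[0] falling through to torch.utils.data.Dataset.__getitem__, which is the base-class stub that raises NotImplementedError; this happens when a subclass never overrides it. A map-style dataset needs __getitem__, plus __len__ so DataLoader's default sampler can iterate over it. The sketch below is a hypothetical replacement, assuming each image (e.g. x.png) has a caption file x.txt beside it, and assuming the training config reads the keys "image" and "txt"; adjust both to match your data and config. It is written as a plain class to stay dependency-free (DataLoader accepts any object with __len__ and __getitem__); in practice you may subclass torch.utils.data.Dataset and do tensor conversion inside load_image.

```python
from pathlib import Path
from typing import Callable, Dict, List, Tuple

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


class CaptionedImageFolder:
    """Hypothetical map-style dataset: pairs each image with a same-named .txt caption."""

    def __init__(self, root: str, load_image: Callable[[Path], object]):
        self.load_image = load_image  # e.g. open, convert to RGB, resize, to tensor
        self.items: List[Tuple[Path, Path]] = []
        for img in sorted(Path(root).iterdir()):
            caption = img.with_suffix(".txt")
            # Keep only image files that actually have a caption next to them.
            if img.suffix.lower() in IMAGE_SUFFIXES and caption.exists():
                self.items.append((img, caption))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> Dict[str, object]:
        img_path, cap_path = self.items[index]
        return {
            "image": self.load_image(img_path),
            "txt": cap_path.read_text().strip(),
        }
```

Usage in the notebook might look like ds = CaptionedImageFolder("../VLoD/", load_image=lambda p: Image.open(p).convert("RGB")) with Image imported from PIL, after which ds[0] returns a dict instead of raising.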

To get the weights, you'll need to go to the model card, read the license, and tick the checkbox if you agree.

In [5]:
!pip install huggingface_hub
from huggingface_hub import notebook_login

notebook_login()
Defaulting to user installation because normal site-packages is not writeable
Requirement already satisfied: huggingface_hub in /home/ubuntu/.local/lib/python3.8/site-packages (0.10.0)
Requirement already satisfied: packaging>=20.9 in /home/ubuntu/.local/lib/python3.8/site-packages (from huggingface_hub) (21.3)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/.local/lib/python3.8/site-packages (from huggingface_hub) (4.3.0)
Requirement already satisfied: requests in /home/ubuntu/.local/lib/python3.8/site-packages (from huggingface_hub) (2.28.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from huggingface_hub) (5.3.1)
Requirement already satisfied: tqdm in /home/ubuntu/.local/lib/python3.8/site-packages (from huggingface_hub) (4.64.1)
Requirement already satisfied: filelock in /usr/lib/python3/dist-packages (from huggingface_hub) (3.0.12)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/lib/python3/dist-packages (from packaging>=20.9->huggingface_hub) (2.4.6)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ubuntu/.local/lib/python3.8/site-packages (from requests->huggingface_hub) (1.26.12)
Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests->huggingface_hub) (2019.11.28)
Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests->huggingface_hub) (2.8)
Requirement already satisfied: charset-normalizer<3,>=2 in /home/ubuntu/.local/lib/python3.8/site-packages (from requests->huggingface_hub) (2.1.1)
In [6]:
from huggingface_hub import hf_hub_download
ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", filename="sd-v1-4-full-ema.ckpt", use_auth_token=True)
Downloading:   0%|          | 0.00/7.70G [00:00<?, ?B/s]
In [14]:
ckpt_path
Out[14]:
'/home/ubuntu/.cache/huggingface/hub/models--CompVis--stable-diffusion-v-1-4-original/snapshots/0834a76f88354683d3f7ef271cadd28f4757a8cc/sd-v1-4-full-ema.ckpt'

Set the parameters below to match your GPU setup. The reference settings were used for training on a 2xA6000 machine (each A6000 has 48GB of VRAM), where good results are achieved in around 6 hours. The values in the cell below have been changed for a single GPU.

You can make up for smaller batches or fewer GPUs by accumulating gradients over several batches:

total batch size = batch size * number of GPUs * accumulated batches
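Here is a worked example of this formula. The sketch below computes the total batch size and the smallest number of accumulated batches that reaches a chosen target. The helper names and the target of 8 are hypothetical. Lightning's `accumulate_grad_batches` steps the optimizer once every N batches, which is why the three factors multiply.

```python
import math


def total_batch_size(batch_size: int, n_gpus: int, accumulate_batches: int) -> int:
    """Number of samples contributing to each optimizer step."""
    return batch_size * n_gpus * accumulate_batches


def accumulation_for_target(target_total: int, batch_size: int, n_gpus: int) -> int:
    """Smallest accumulate_grad_batches giving at least target_total samples per step."""
    per_step = batch_size * n_gpus
    return max(1, math.ceil(target_total / per_step))


# Hypothetical target of 8 samples per step with BATCH_SIZE = 2 on 1 GPU.
acc = accumulation_for_target(8, batch_size=2, n_gpus=1)
print(acc, total_batch_size(2, 1, acc))  # 4 8
```

Rounding up means the total can slightly exceed the target when it isn't divisible by `batch_size * n_gpus`. This errs towards a larger rather than a smaller effective batch.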

In [12]:
# Single-GPU settings; for the 2xA6000 setup described above, set N_GPUS = 2
BATCH_SIZE = 2
N_GPUS = 1
ACCUMULATE_BATCHES = 1

gpu_list = ",".join((str(x) for x in range(N_GPUS))) + ","
print(f"Using GPUs: {gpu_list}")
Using GPUs: 0,
In [11]:
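# Run training. Run the parameter cell above first so that gpu_list, BATCH_SIZE,
# N_GPUS and ACCUMULATE_BATCHES are defined: IPython leaves a ! command
# unexpanded if any $variable is missing, and the shell then substitutes ''.
!(python main.py \
    -t \
    --base configs/stable-diffusion/paris.yaml \
    --gpus "$gpu_list" \
    --scale_lr False \
    --num_nodes 1 \
    --check_val_every_n_epoch 10 \
    --finetune_from "$ckpt_path" \
    data.params.batch_size="$BATCH_SIZE" \
    lightning.trainer.accumulate_grad_batches="$ACCUMULATE_BATCHES" \
    data.params.validation.params.n_gpus="$N_GPUS" \
)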
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
Moving 0 files to the new cache system
0it [00:00, ?it/s]
usage: main.py [-h] [--finetune_from [FINETUNE_FROM]] [-n [NAME]]
               [-r [RESUME]] [-b [base_config.yaml [base_config.yaml ...]]]
               [-t [TRAIN]] [--no-test [NO_TEST]] [-p PROJECT] [-d [DEBUG]]
               [-s SEED] [-f POSTFIX] [-l LOGDIR] [--scale_lr [SCALE_LR]]
               [--logger [LOGGER]]
               [--checkpoint_callback [CHECKPOINT_CALLBACK]]
               [--default_root_dir DEFAULT_ROOT_DIR]
               [--gradient_clip_val GRADIENT_CLIP_VAL]
               [--gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
               [--process_position PROCESS_POSITION] [--num_nodes NUM_NODES]
               [--num_processes NUM_PROCESSES] [--devices DEVICES]
               [--gpus GPUS] [--auto_select_gpus [AUTO_SELECT_GPUS]]
               [--tpu_cores TPU_CORES] [--ipus IPUS]
               [--log_gpu_memory LOG_GPU_MEMORY]
               [--progress_bar_refresh_rate PROGRESS_BAR_REFRESH_RATE]
               [--overfit_batches OVERFIT_BATCHES]
               [--track_grad_norm TRACK_GRAD_NORM]
               [--check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH]
               [--fast_dev_run [FAST_DEV_RUN]]
               [--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES]
               [--max_epochs MAX_EPOCHS] [--min_epochs MIN_EPOCHS]
               [--max_steps MAX_STEPS] [--min_steps MIN_STEPS]
               [--max_time MAX_TIME]
               [--limit_train_batches LIMIT_TRAIN_BATCHES]
               [--limit_val_batches LIMIT_VAL_BATCHES]
               [--limit_test_batches LIMIT_TEST_BATCHES]
               [--limit_predict_batches LIMIT_PREDICT_BATCHES]
               [--val_check_interval VAL_CHECK_INTERVAL]
               [--flush_logs_every_n_steps FLUSH_LOGS_EVERY_N_STEPS]
               [--log_every_n_steps LOG_EVERY_N_STEPS]
               [--accelerator ACCELERATOR] [--sync_batchnorm [SYNC_BATCHNORM]]
               [--precision PRECISION] [--weights_summary WEIGHTS_SUMMARY]
               [--weights_save_path WEIGHTS_SAVE_PATH]
               [--num_sanity_val_steps NUM_SANITY_VAL_STEPS]
               [--truncated_bptt_steps TRUNCATED_BPTT_STEPS]
               [--resume_from_checkpoint RESUME_FROM_CHECKPOINT]
               [--profiler PROFILER] [--benchmark [BENCHMARK]]
               [--deterministic [DETERMINISTIC]]
               [--reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS]
               [--reload_dataloaders_every_epoch [RELOAD_DATALOADERS_EVERY_EPOCH]]
               [--auto_lr_find [AUTO_LR_FIND]]
               [--replace_sampler_ddp [REPLACE_SAMPLER_DDP]]
               [--terminate_on_nan [TERMINATE_ON_NAN]]
               [--auto_scale_batch_size [AUTO_SCALE_BATCH_SIZE]]
               [--prepare_data_per_node [PREPARE_DATA_PER_NODE]]
               [--plugins PLUGINS] [--amp_backend AMP_BACKEND]
               [--amp_level AMP_LEVEL]
               [--distributed_backend DISTRIBUTED_BACKEND]
               [--move_metrics_to_cpu [MOVE_METRICS_TO_CPU]]
               [--multiple_trainloader_mode MULTIPLE_TRAINLOADER_MODE]
               [--stochastic_weight_avg [STOCHASTIC_WEIGHT_AVG]]
main.py: error: argument --gpus: invalid _gpus_allowed_type value: ''
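The `invalid _gpus_allowed_type value: ''` error means `--gpus` received an empty string. The command as run referenced `$NUM_GPUS`, which is never defined (the variable is `N_GPUS`). When any `$name` cannot be resolved, IPython leaves the whole `!` command untransformed, so the shell substitutes empty strings for every variable. The execution counts also show that the training cell (`In [11]`) ran before the parameter cell (`In [12]`).

One way to avoid this class of error is to build the argument list in Python and fail early on bad inputs. The following is a sketch under that assumption. `build_train_cmd` is a hypothetical helper that mirrors the flags used above, and the checkpoint path in the demo is a placeholder.

```python
import shlex
import sys
from typing import List


def build_train_cmd(
    config: str,
    ckpt_path: str,
    batch_size: int,
    n_gpus: int,
    accumulate_batches: int,
) -> List[str]:
    """Assemble the main.py arguments used above, validating inputs first."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be >= 1")
    if not ckpt_path:
        raise ValueError("ckpt_path is empty; run the download cell first")
    # Trailing comma: Lightning reads "0," as a list of device indices.
    gpu_list = ",".join(str(i) for i in range(n_gpus)) + ","
    return [
        sys.executable, "main.py",  # the kernel's interpreter instead of PATH python
        "-t",
        "--base", config,
        "--gpus", gpu_list,
        "--scale_lr", "False",
        "--num_nodes", "1",
        "--check_val_every_n_epoch", "10",
        "--finetune_from", ckpt_path,
        f"data.params.batch_size={batch_size}",
        f"lightning.trainer.accumulate_grad_batches={accumulate_batches}",
        f"data.params.validation.params.n_gpus={n_gpus}",
    ]


cmd = build_train_cmd(
    "configs/stable-diffusion/paris.yaml", "/path/to/sd-v1-4-full-ema.ckpt", 2, 1, 1
)
print(shlex.join(cmd))
# To launch from Python: import subprocess; subprocess.run(cmd, check=True)
```

Passing a list to `subprocess.run` skips shell expansion entirely, so a missing value raises a Python error instead of silently becoming `''`.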
In [ ]:
# Run the model
# Replace 'path/to/your/checkpoint' with a checkpoint saved by the training run
!(python scripts/txt2img.py \
    --prompt 'robotic cat with wings' \
    --outdir 'outputs/generated_pokemon' \
    --H 512 --W 512 \
    --n_samples 4 \
    --config 'configs/stable-diffusion/pokemon.yaml' \
    --ckpt 'path/to/your/checkpoint')
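The `--ckpt` value above is a placeholder. In this CompVis-derived `main.py`, checkpoints are, as far as we know, written under `logs/<run name>/checkpoints/` (with `last.ckpt` as the most recent). Verify this against your own log directory. The sketch below picks the most recently modified `last.ckpt`. The `logs` default and the helper name `latest_checkpoint` are assumptions.

```python
import glob
import os
from typing import Optional


def latest_checkpoint(logdir: str = "logs", name: str = "last.ckpt") -> Optional[str]:
    """Return the most recently modified <logdir>/*/checkpoints/<name>, or None."""
    pattern = os.path.join(logdir, "*", "checkpoints", name)
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    # Newest by modification time, so the latest run wins.
    return max(candidates, key=os.path.getmtime)


print(latest_checkpoint())  # None if no training run has saved a checkpoint yet
```

Pass the returned path to `--ckpt` in place of `path/to/your/checkpoint`.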