Compare commits
3 Commits
35c61c0c5b
...
73ce5c4bfd
Author | SHA1 | Date |
---|---|---|
Ruben van de Ven | 73ce5c4bfd | |
Ruben van de Ven | c1f7429ca1 | |
Ruben van de Ven | e2db2688e0 |
File diff suppressed because one or more lines are too long
18
Dockerfile
18
Dockerfile
|
@ -12,6 +12,24 @@ ENV PYTHONDONTWRITEBYTECODE 1
|
|||
ENV PYTHONUNBUFFERED 1
|
||||
|
||||
RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
|
||||
RUN pip install ipywidgets
|
||||
|
||||
#When X11 forwarding matplotlib
|
||||
#RUN pip install cairocffi
|
||||
|
||||
|
||||
RUN apt-get update -y
|
||||
ENV TZ=Europe/Amsterdam
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
#RUN apt-get install -y libcairo2 python3-gi python3-gi-cairo gir1.2-gtk-3.0
|
||||
RUN apt-get install -y libgirepository1.0-dev gcc libcairo2-dev pkg-config python3-dev gir1.2-gtk-3.0
|
||||
RUN pip install pycairo
|
||||
RUN pip install PyGObject
|
||||
RUN apt-get install -y mesa-utils
|
||||
|
||||
# ffmpeg for cv2 video creation
|
||||
RUN apt-get install -y ffmpeg
|
||||
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
3102
Stylegan3.ipynb
3102
Stylegan3.ipynb
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,631 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "f4ee99c4-9c28-4fe4-9408-e130a0d446d3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from runs import Run, Snapshot, get_projections_in_dir, get_runs_in_dir\n",
|
||||
"from scipy.ndimage.filters import uniform_filter1d\n",
|
||||
"import cv2\n",
|
||||
"from PIL import Image\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "d362cae7-c5c9-4127-b221-92c859aa1620",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def is_main():\n",
|
||||
" \"\"\"Return True if this notebook is being run by calling\n",
|
||||
" %run in another notebook, False otherwise.\n",
|
||||
" works around jupyter bug: https://github.com/ipython/ipython/issues/10967\n",
|
||||
" \"\"\"\n",
|
||||
" try:\n",
|
||||
" __file__\n",
|
||||
" # __file__ has been defined, so this notebook is \n",
|
||||
" # being run in a parent notebook\n",
|
||||
" return True\n",
|
||||
"\n",
|
||||
" except NameError:\n",
|
||||
" # __file__ has not been defined, so this notebook is \n",
|
||||
" # not being run in a parent notebook\n",
|
||||
" return False\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "2544f23f-3b2c-4d29-95a6-031cecab1e08",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"args = {\n",
|
||||
" 'runs_dir': 'training-runs', \n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "4c428611-8d75-4f9a-ae5a-7960e7b01470",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"runs = get_runs_in_dir(args['runs_dir'])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7f77f181-059a-4568-8088-1eb86a1a172a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"See [Snapshot_images.ipynb](Snapshot_images.ipynb) for examples of each run/snapshot."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "9a5c4a18-0389-4d36-989e-7a120db590ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# snapshot = runs[3].snapshots[70]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c3aa8404-aeb0-4f63-b4b4-c278f0cf3766",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Plot run metrics\n",
|
||||
"\n",
|
||||
"We can plot the progress of the metrics (fid) for each run. Sommige runs zijn een vervolg op een eerdere run. Dit zou kunnen zijn om het netwerk een voorgetraind startpunt te geven, maar in dit geval was het vooral omdat de training zo nu en dan was gestopt en weer herstart (wat een 'nieuwe' run geeft)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "c180db45-9bf1-4f0a-ab55-62b476a5897c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# importing package\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"def plot_runs(runs, dpi=300, palette=None):\n",
|
||||
" \n",
|
||||
" plt.figure(dpi=dpi)\n",
|
||||
" plt.yscale('log')\n",
|
||||
" for i, run in enumerate(runs):\n",
|
||||
" x = [s.cumulative_iteration for s in run.snapshots]\n",
|
||||
" y = [s.fid for s in run.snapshots]\n",
|
||||
" # plot lines\n",
|
||||
" c = palette[i%len(palette)] if palette else None\n",
|
||||
" plt.plot(x, y, label = f\"{i} {run.id}\", c=c)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" return plt\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "22c74f41-65e7-461a-9094-f0b3d8738c82",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'runs' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_1/862876608.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# def is_main():\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mplot\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplot_runs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mruns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbbox_to_anchor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"lower left\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mplot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'runs' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def is_main():\n",
|
||||
" plot = plot_runs(runs)\n",
|
||||
" plot.legend(bbox_to_anchor=(1,0), loc=\"lower left\")\n",
|
||||
" plot.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "b28462dc-63a7-4630-8cf2-44eb2869661a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_stats(stat_ids, runs, dpi=300, palette=None): \n",
|
||||
" fig2, axes = plt.subplots(nrows=1, ncols=len(stat_ids), figsize=(10*len(stat_ids), 9), dpi=dpi)\n",
|
||||
"\n",
|
||||
" for i, ax in enumerate(axes):\n",
|
||||
" ax.set_xlabel('kimg')\n",
|
||||
" ax.set_ylabel(stat_ids[i])\n",
|
||||
" ax.set_yscale('symlog', linthresh=1) # 0-1: linear, >1: log scale\n",
|
||||
"\n",
|
||||
" for i, run in enumerate(runs):\n",
|
||||
" stats = [\n",
|
||||
" [\n",
|
||||
" s['Progress/kimg']['mean'] + run.kimg_offset\n",
|
||||
" ] + [[s[sid]['mean'], s[sid]['std']] for sid in stat_ids]\n",
|
||||
" for s in run.get_stats()\n",
|
||||
" ]\n",
|
||||
" x = [ s[0] for s in stats ]\n",
|
||||
" \n",
|
||||
" c = palette[i%len(palette)] if palette else None\n",
|
||||
" # smooth slightly for better readability\n",
|
||||
" for i, stat_id in enumerate(stat_ids):\n",
|
||||
" error = [s[i+1][1] for s in stats]\n",
|
||||
" y = uniform_filter1d([s[i+1][0] for s in stats], size=20)\n",
|
||||
" axes[i].plot(x, y, label = f\"{i} {run.id}\", c=c)\n",
|
||||
" # draw std dev:\n",
|
||||
" # axes[i].fill_between(x, y-error, y+error,\n",
|
||||
" # alpha=0.2,\n",
|
||||
" # antialiased=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # print(x,y)\n",
|
||||
" # x = [s.cumulative_iteration for s in stats]\n",
|
||||
" # y = [s.fid for s in run.stats]\n",
|
||||
" # # plot lines\n",
|
||||
" # ax2.plot(x, y2, label = f\"{i} {run.id}\")\n",
|
||||
" return plt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8bb5a2db-4e04-4402-953c-262a4a010c96",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def is_main():\n",
|
||||
" plot = plot_stats([\n",
|
||||
" 'Loss/D/loss',\n",
|
||||
" 'Loss/G/loss',\n",
|
||||
" ], runs)\n",
|
||||
" plot.legend()\n",
|
||||
" plot.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9c4d937f-0035-4760-a7de-93080dfe5438",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# excerpts\n",
|
||||
"\n",
|
||||
"Als we de runs met de laagste FID scores bekijken, krijgen we een beeld van de kwaliteit van de netwerken.\n",
|
||||
"\n",
|
||||
"Wat vooral opvalt is dat het netwerk met gecropte beelden (00014+16) de meest kleurrijke beelden geeft en dus niet convergeert naar een beige-grijs zoals veel van de andere netwerken."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "461480c4-d4fe-4fc8-a7fa-6e66d468595a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def is_main():\n",
|
||||
" display(\n",
|
||||
" runs[-1].snapshots[-10].iteration,\n",
|
||||
" runs[-1].snapshots[-10].get_preview_img(8,1),\n",
|
||||
" runs[-1].snapshots[-2].iteration,\n",
|
||||
" runs[-1].snapshots[-2].get_preview_img(8,1),\n",
|
||||
" runs[-1].snapshots[-1].iteration,\n",
|
||||
" runs[-1].snapshots[-1].get_preview_img(8,1),\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9782b289-101d-4bc4-b9e4-8d0838dec09b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def is_main():\n",
|
||||
" display(\n",
|
||||
" runs[3].snapshots[-1].get_preview_img(4,1),\n",
|
||||
" runs[2].snapshots[-1].get_preview_img(4,1),\n",
|
||||
" runs[5].snapshots[-1].get_preview_img(4,1)\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22668f34-fa75-4d05-abeb-c4b41bec8093",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Stylegan 3 functions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de9c301f-5a20-4eb7-87c1-83c8e67508f0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Helper functions for Stylegan 3 operations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "b31b0b4b-7fde-4a8b-812e-5545c68ce64d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ab78ac4c-5e2d-456c-9afd-9639ad963b51",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Convert seeds to `z`-space and `z`-space to `w`-space, use `w`-space to generate images, and turn generated images into Jupyter widgets for visualisation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "bb0593f9-05df-47e7-9555-0af43951a49b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# adapted from https://github.com/dvschultz/stylegan2-ada-pytorch/blob/9b6750b96dc9841816e8ac57b05f395d0f23c30d/generate.py\n",
|
||||
"\n",
|
||||
"def seeds_to_zs(G,seeds):\n",
|
||||
" zs = []\n",
|
||||
" for seed_idx, seed in enumerate(seeds):\n",
|
||||
" z = np.random.RandomState(seed).randn(1, G.z_dim)\n",
|
||||
" zs.append(z)\n",
|
||||
" return zs\n",
|
||||
"\n",
|
||||
"def zs_to_ws(G,device,label,truncation_psi,zs):\n",
|
||||
" ws = []\n",
|
||||
" for z in zs:\n",
|
||||
" z = torch.from_numpy(z).to(device)\n",
|
||||
" w = G.mapping(z, label, truncation_psi=truncation_psi, truncation_cutoff=8)\n",
|
||||
" ws.append(w)\n",
|
||||
" return ws\n",
|
||||
"\n",
|
||||
"def images(G,device,inputs,space,truncation_psi,label,noise_mode,start=None,stop=None):\n",
|
||||
" \"\"\"Generate image for z or w space image (deprecated)\"\"\"\n",
|
||||
" if(start is not None and stop is not None):\n",
|
||||
" tp = start\n",
|
||||
" tp_i = (stop-start)/len(inputs)\n",
|
||||
"\n",
|
||||
" for idx, i in enumerate(inputs):\n",
|
||||
" # print('Generating image for frame %d/%d ...' % (idx, len(inputs)))\n",
|
||||
" \n",
|
||||
" if (space=='z'):\n",
|
||||
" z = torch.from_numpy(i).to(device)\n",
|
||||
" if(start is not None and stop is not None):\n",
|
||||
" img = G(z, label, truncation_psi=tp, noise_mode=noise_mode)\n",
|
||||
" tp = tp+tp_i\n",
|
||||
" else:\n",
|
||||
" img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)\n",
|
||||
" else:\n",
|
||||
" if len(i.shape) == 2: \n",
|
||||
" i = torch.from_numpy(i).unsqueeze(0).to(device)\n",
|
||||
" img = G.synthesis(i, noise_mode=noise_mode, force_fp32=True)\n",
|
||||
" img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
|
||||
" yield f\"{idx:04d}\", Image.fromarray(img[0].cpu().numpy(), 'RGB')\n",
|
||||
"\n",
|
||||
"def w_to_img(G, device, noise_mode, w):\n",
|
||||
" img = G.synthesis(w, noise_mode=noise_mode, force_fp32=True)\n",
|
||||
" img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
|
||||
" return Image.fromarray(img[0].cpu().numpy(), 'RGB')\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "efafbb3a-7d2c-4b6a-a626-f06dd1b2b1ec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Interpolation of vectors, to browse latent space."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "014449b1-1293-49bb-90bf-a1949315183c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def interpolate(G,device,projected_w,seeds,random_seed,space,truncation_psi,label,frames,noise_mode,outdir,interpolation,\n",
|
||||
" easing, diameter, start=None,stop=None):\n",
|
||||
" if(interpolation=='noiseloop' or interpolation=='circularloop'):\n",
|
||||
" if seeds is not None:\n",
|
||||
" print(f'Warning: interpolation type: \"{interpolation}\" doesn’t support set seeds.')\n",
|
||||
"\n",
|
||||
" if(interpolation=='noiseloop'):\n",
|
||||
" points = noiseloop(frames, diameter, random_seed)\n",
|
||||
" elif(interpolation=='circularloop'):\n",
|
||||
" points = circularloop(frames, diameter, random_seed, seeds)\n",
|
||||
"\n",
|
||||
" else:\n",
|
||||
" if projected_w is not None:\n",
|
||||
" points = np.load(projected_w)['w']\n",
|
||||
" else:\n",
|
||||
" # get zs from seeds\n",
|
||||
" points = seeds_to_zs(G,seeds) \n",
|
||||
" # convert to ws\n",
|
||||
" if(space=='w'):\n",
|
||||
" points = zs_to_ws(G,device,label,truncation_psi,points)\n",
|
||||
"\n",
|
||||
" # get interpolation points\n",
|
||||
" if(interpolation=='linear'):\n",
|
||||
" points = line_interpolate(points,frames,easing)\n",
|
||||
" elif(interpolation=='slerp'):\n",
|
||||
" points = slerp_interpolate(points,frames)\n",
|
||||
" \n",
|
||||
" # generate frames\n",
|
||||
" for idx, img in images(G,device,points,space,truncation_psi,label,noise_mode,start,stop):\n",
|
||||
" yield idx, img\n",
|
||||
"\n",
|
||||
"# slightly modified version of\n",
|
||||
"# https://github.com/PDillis/stylegan2-fun/blob/master/run_generator.py#L399\n",
|
||||
"def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):\n",
|
||||
" '''\n",
|
||||
" Spherical linear interpolation\n",
|
||||
" Args:\n",
|
||||
" t (float/np.ndarray): Float value between 0.0 and 1.0\n",
|
||||
" v0 (np.ndarray): Starting vector\n",
|
||||
" v1 (np.ndarray): Final vector\n",
|
||||
" DOT_THRESHOLD (float): Threshold for considering the two vectors as\n",
|
||||
" colineal. Not recommended to alter this.\n",
|
||||
" Returns:\n",
|
||||
" v2 (np.ndarray): Interpolation vector between v0 and v1\n",
|
||||
" '''\n",
|
||||
" v0 = v0.cpu().detach().numpy() if hasattr(v0, 'cpu') else v0\n",
|
||||
" v1 = v1.cpu().detach().numpy() if hasattr(v1, 'cpu') else v1\n",
|
||||
" # Copy the vectors to reuse them later\n",
|
||||
" v0_copy = np.copy(v0)\n",
|
||||
" v1_copy = np.copy(v1)\n",
|
||||
" # Normalize the vectors to get the directions and angles\n",
|
||||
" v0 = v0 / np.linalg.norm(v0)\n",
|
||||
" v1 = v1 / np.linalg.norm(v1)\n",
|
||||
" # Dot product with the normalized vectors (can't use np.dot in W)\n",
|
||||
" dot = np.sum(v0 * v1)\n",
|
||||
" # If absolute value of dot product is almost 1, vectors are ~colineal, so use lerp\n",
|
||||
" if np.abs(dot) > DOT_THRESHOLD:\n",
|
||||
" return lerp(t, v0_copy, v1_copy)\n",
|
||||
" # Calculate initial angle between v0 and v1\n",
|
||||
" theta_0 = np.arccos(dot)\n",
|
||||
" sin_theta_0 = np.sin(theta_0)\n",
|
||||
" # Angle at timestep t\n",
|
||||
" theta_t = theta_0 * t\n",
|
||||
" sin_theta_t = np.sin(theta_t)\n",
|
||||
" # Finish the slerp algorithm\n",
|
||||
" s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n",
|
||||
" s1 = sin_theta_t / sin_theta_0\n",
|
||||
" v2 = s0 * v0_copy + s1 * v1_copy\n",
|
||||
" return torch.from_numpy(v2).to(\"cuda\")\n",
|
||||
"\n",
|
||||
"def slerp_interpolate(zs, steps):\n",
|
||||
" out = []\n",
|
||||
" for i in range(len(zs)-1):\n",
|
||||
" for index in range(steps):\n",
|
||||
" fraction = index/float(steps)\n",
|
||||
" out.append(slerp(fraction,zs[i],zs[i+1]))\n",
|
||||
" return out\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "60be2986-914d-4032-8554-cdd3ab52f827",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Project an image to the latent space"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f802546c-6e74-440f-8d88-9e5a166b319d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess, operator\n",
|
||||
"\n",
|
||||
"def project_img_to_z(snapshot: Snapshot, image_filename: str, steps=1000, replace_if_exists=False) -> dict:\n",
|
||||
" \n",
|
||||
" # imagenr = image_filename[-12:-4]\n",
|
||||
" image_name = image_filename[:-4]\n",
|
||||
" runnr = snapshot.run.as_nr\n",
|
||||
" # !!python pbaylies_projector.py --network $snapshot_pkl --outdir out/projections/$runnr-$imagenr --target-image $image_filename --use-clip=False\n",
|
||||
" \n",
|
||||
" if replace_if_exists or not os.path.exists(f\"out/projections/{runnr}/{image_name}/proj.png\"):\n",
|
||||
" process = subprocess.Popen([\n",
|
||||
" \"python\", \"pbaylies_projector.py\",\n",
|
||||
" \"--network\" , snapshot.pkl_path,\n",
|
||||
" \"--outdir\", f\"out/projections/{runnr}/{image_name}\",\n",
|
||||
" \"--target-image\", image_filename,\n",
|
||||
" \"--use-clip\", \"False\",\n",
|
||||
" \"--num-steps\", str(steps),\n",
|
||||
" \"--save-video\", \"False\"\n",
|
||||
" ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n",
|
||||
" stdout, stderr = process.communicate()\n",
|
||||
" lines = stdout.split(\"\\n\")\n",
|
||||
" lossess_and_distances = [operator.itemgetter(-1,-3)(line.split(\" \")) for line in lines if line.startswith(\"step\")]\n",
|
||||
" print(stderr)\n",
|
||||
" loss, dist = lossess_and_distances[-1]\n",
|
||||
" else:\n",
|
||||
" # TODO: get loss and dist from somewhere? (currently not using it much)\n",
|
||||
" loss, dist = (None, None)\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"img\": f\"out/projections/{runnr}/{image_name}/proj.png\",\n",
|
||||
" \"src_img\": f\"out/projections/{runnr}/{image_name}/target.png\",\n",
|
||||
" \"src\": image_filename,\n",
|
||||
" \"npz\": f\"out/projections/{runnr}/{image_name}/projected_w.npz\",\n",
|
||||
" \"loss\": loss,\n",
|
||||
" \"dist\": dist\n",
|
||||
" }\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6f10a7eb-5056-41da-9cfd-894a1c772e0d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Displaying Videos and Images"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"id": "f7ccb1c5-13c3-4ff9-9030-125622ccf93a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def img_to_widget(img):\n",
|
||||
" buff = io.BytesIO()\n",
|
||||
" img.save(buff, format='png')\n",
|
||||
" \n",
|
||||
" return widgets.Image(value=buff.getvalue(), format='png', width=run_data['resolution'], height=run_data['resolution'])\n",
|
||||
"\n",
|
||||
"def video_to_widget(filename):\n",
|
||||
" with open(filename, 'rb') as fp:\n",
|
||||
" video = fp.read()\n",
|
||||
" return widgets.Video(value=video)\n",
|
||||
"\n",
|
||||
"def image_grid(imgs, cols=None, rows=None, margin = 10):\n",
|
||||
"# create image grid, if no size is given, put all on horizontal axis\n",
|
||||
" if cols is None or rows is None:\n",
|
||||
" cols = len(imgs)\n",
|
||||
" rows = 1\n",
|
||||
" \n",
|
||||
" w, h = imgs[0].size\n",
|
||||
" w, h = w+margin, h+margin\n",
|
||||
" grid = Image.new('RGB', size=(cols*w-margin, rows*h-margin))\n",
|
||||
" grid_w, grid_h = grid.size\n",
|
||||
" \n",
|
||||
" for i, img in enumerate(imgs):\n",
|
||||
" grid.paste(img, box=(i%cols*w, i//cols*h))\n",
|
||||
" return grid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c988053f-4a6d-4f66-a40d-593bf5a35765",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Postprocessing"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "573284e5-1c27-415c-9876-4a4a7e6fc704",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Sometimes we like to generate videos from a series of ws"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"id": "7fd5f974-3f44-4f92-b1b8-7564400229fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from imageio_ffmpeg import write_frames\n",
|
||||
"\n",
|
||||
"def generator_to_video(generator, out_filename, fps, frame_size, quality):\n",
|
||||
" writer = write_frames(out_filename, frame_size, quality=quality) # size is (width, height)\n",
|
||||
" writer.send(None) # seed the generator\n",
|
||||
" print(os.path.abspath(out_filename))\n",
|
||||
" # output = cv2.VideoWriter(\n",
|
||||
" # out_filename,\n",
|
||||
" # # see http://mp4ra.org/#/codecs for codecs\n",
|
||||
" # cv2.VideoWriter_fourcc(*'vp09'),\n",
|
||||
" # fps,\n",
|
||||
" # frame_size)\n",
|
||||
" for frame in generator:\n",
|
||||
" if type(frame) is Image.Image:\n",
|
||||
" open_cv_image = np.array(frame) \n",
|
||||
" # Reverse the channel order (last axis), e.g. RGB -> BGR\n",
|
||||
" frame = open_cv_image[:, :, ::-1].copy()\n",
|
||||
" # output.write(frame)\n",
|
||||
" writer.send(frame)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" # output.release()\n",
|
||||
" writer.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"id": "6ff4c6db-b48f-4ff4-b08c-e85138f0f307",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"id": "d0661104-72a7-4320-980a-1a702388659f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e2aa2c98-cf1e-465a-9455-fe4a02f145ad",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "paris-stylegan3",
|
||||
"language": "python",
|
||||
"name": "paris-stylegan3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
|
@ -26,6 +26,9 @@ import numpy as np
|
|||
import PIL.Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
PIL.Image.init() # required to initialise PIL.Image.EXTENSION
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def error(msg):
|
||||
|
@ -216,8 +219,11 @@ def open_mnist(images_gz: str, *, max_images: Optional[int]):
|
|||
def make_transform(
|
||||
transform: Optional[str],
|
||||
output_width: Optional[int],
|
||||
output_height: Optional[int]
|
||||
output_height: Optional[int],
|
||||
crop_width: Optional[int],
|
||||
crop_height: Optional[int]
|
||||
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
|
||||
|
||||
def scale(width, height, img):
|
||||
w = img.shape[1]
|
||||
h = img.shape[0]
|
||||
|
@ -249,17 +255,38 @@ def make_transform(
|
|||
canvas = np.zeros([width, width, 3], dtype=np.uint8)
|
||||
canvas[(width - height) // 2 : (width + height) // 2, :] = img
|
||||
return canvas
|
||||
|
||||
def scale_center_crop(width, height, crop_w, crop_h, img):
    """Cut a centered window out of ``img`` and hand it to ``scale``.

    The window covers ``crop_w`` entries along axis 0 and ``crop_h`` entries
    along axis 1 of ``img``; ``width`` and ``height`` are forwarded to
    ``scale`` unchanged.

    NOTE(review): ``scale`` reads ``img.shape[1]`` as width and
    ``img.shape[0]`` as height, so the ``crop_w``/``crop_h`` names look
    swapped relative to the axes they act on. The call site also passes the
    output height before the output width, so confirm the intended
    convention before changing either side.
    """
    axis0_len = img.shape[0]
    axis1_len = img.shape[1]
    axis0_window = slice((axis0_len - crop_w) // 2, (axis0_len + crop_w) // 2)
    axis1_window = slice((axis1_len - crop_h) // 2, (axis1_len + crop_h) // 2)
    return scale(width, height, img[axis0_window, axis1_window])
|
||||
|
||||
def scale_center_crop_wide(width, height, crop_w, crop_h, img):
    """Placeholder for the 'center-crop-wide' transform with a separate crop size.

    Not implemented: the body only reports that through ``error``; none of
    the arguments are used.
    """
    error('not implemented')
|
||||
|
||||
if transform is None:
|
||||
return functools.partial(scale, output_width, output_height)
|
||||
if transform == 'center-crop':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + 'transform')
|
||||
return functools.partial(center_crop, output_width, output_height)
|
||||
if transform == 'center-crop-wide':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + ' transform')
|
||||
return functools.partial(center_crop_wide, output_width, output_height)
|
||||
crop_width = output_width if crop_width is None else crop_width
|
||||
crop_height = output_height if crop_height is None else crop_height
|
||||
|
||||
if crop_width != output_width or crop_height != output_height:
|
||||
if transform is None:
|
||||
error ('must specify transform method (center-crop or center-crop-wide)')
|
||||
if transform == 'center-crop':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + 'transform')
|
||||
return functools.partial(scale_center_crop, output_height, output_width, crop_width, crop_height)
|
||||
if transform == 'center-crop-wide':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + ' transform')
|
||||
return functools.partial(scale_center_crop_wide, output_height, output_width, crop_width, crop_height)
|
||||
else:
|
||||
if transform is None:
|
||||
return functools.partial(scale, output_width, output_height)
|
||||
if transform == 'center-crop':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + 'transform')
|
||||
return functools.partial(center_crop, output_width, output_height)
|
||||
if transform == 'center-crop-wide':
|
||||
if (output_width is None) or (output_height is None):
|
||||
error ('must specify --resolution=WxH when using ' + transform + ' transform')
|
||||
return functools.partial(center_crop_wide, output_width, output_height)
|
||||
assert False, 'unknown transform'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
@ -323,13 +350,15 @@ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None],
|
|||
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
|
||||
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
|
||||
@click.option('--resolution', help='Output resolution (e.g., \'512x512\')', metavar='WxH', type=parse_tuple)
|
||||
@click.option('--crop-resolution', help='Resolution of crop (can be larger small than final resoltion) (e.g., \'600x600\')', metavar='WxH', type=parse_tuple)
|
||||
def convert_dataset(
|
||||
ctx: click.Context,
|
||||
source: str,
|
||||
dest: str,
|
||||
max_images: Optional[int],
|
||||
transform: Optional[str],
|
||||
resolution: Optional[Tuple[int, int]]
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
crop_resolution: Optional[Tuple[int, int]],
|
||||
):
|
||||
"""Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
|
||||
|
||||
|
@ -387,7 +416,7 @@ def convert_dataset(
|
|||
|
||||
\b
|
||||
python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
|
||||
--transform=center-crop-wide --resolution=512x384
|
||||
--transform=center-crop-wide --resolution=512x384 --crop-resolution=600x600
|
||||
"""
|
||||
|
||||
PIL.Image.init() # type: ignore
|
||||
|
@ -399,7 +428,8 @@ def convert_dataset(
|
|||
archive_root_dir, save_bytes, close_dest = open_dest(dest)
|
||||
|
||||
if resolution is None: resolution = (None, None)
|
||||
transform_image = make_transform(transform, *resolution)
|
||||
if crop_resolution is None: crop_resolution = (None, None)
|
||||
transform_image = make_transform(transform, *resolution, *crop_resolution)
|
||||
|
||||
dataset_attrs = None
|
||||
|
||||
|
|
|
@ -0,0 +1,400 @@
|
|||
# Modified StyleGAN2 Projector with CLIP, addl. losses, kmeans, etc.
|
||||
# by Peter Baylies, 2021 -- @pbaylies on Twitter
|
||||
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Project given image to the latent space of pretrained network pickle."""
|
||||
|
||||
import copy
|
||||
import os
|
||||
from time import perf_counter
|
||||
|
||||
import click
|
||||
import imageio
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
from PIL import ImageFilter
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import dnnlib
|
||||
import legacy
|
||||
|
||||
# Per-channel (3-value) mean and standard deviation that score_images
# subtracts from / divides into its image batch before calling the model.
# NOTE(review): presumably the CLIP preprocessing statistics — confirm.
# NOTE: .cuda() is called at import time, so importing this module needs a
# CUDA device to be available.
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
|
||||
|
||||
def score_images(G, model, text, latents, device, label_class = 0, batch_size = 8):
    """Synthesize images from ``latents`` in batches and score them with ``model``.

    Only ``latents.shape[0] // batch_size`` complete batches are processed;
    trailing latents that do not fill a batch are skipped. ``label_class``
    is accepted but not used.

    Args:
        G: generator; only ``G.synthesis(ws, noise_mode='const')`` is called.
        model: callable invoked as ``model(image_input, text)``; element
            ``[0]`` of its result is taken as the batch score.
        text: passed through to ``model`` unchanged.
        latents: indexed as ``latents[a:b, :, :]`` and converted to a
            float32 tensor on ``device``.
        device: device for the latent tensors.
        label_class: unused.
        batch_size: number of latents per synthesis call.

    Returns:
        ``(scores, all_images)`` as NumPy arrays. ``scores`` is
        ``1 - s / np.linalg.norm(s)`` where ``s`` are the merged and squeezed
        batch scores; ``all_images`` are the raw synthesis outputs with the
        batch dimension merged.
    """
    per_batch_scores = []
    per_batch_images = []
    n_batches = latents.shape[0] // batch_size
    for batch_idx in range(n_batches):
        lo = batch_idx * batch_size
        hi = lo + batch_size
        ws = torch.tensor(latents[lo:hi, :, :], dtype=torch.float32, device=device)
        images = G.synthesis(ws, noise_mode='const')
        with torch.no_grad():
            # Map the clamped [-1, 1] output to [0, 1].
            model_input = (torch.clamp(images, -1, 1) + 1) * 0.5
            model_input = F.interpolate(model_input, size=(256, 256), mode='area')
            model_input = model_input[:, :, 16:240, 16:240]  # 256 -> 224, center crop
            # Per-channel normalisation with the module-level statistics.
            model_input -= image_mean[None, :, None, None]
            model_input /= image_std[None, :, None, None]
            batch_score = model(model_input, text)[0]
            per_batch_scores.append(batch_score.cpu().numpy())
            per_batch_images.append(images.cpu().numpy())

    # Merge the (batch, item, ...) axes into one leading axis.
    scores = np.array(per_batch_scores)
    scores = scores.reshape(-1, *scores.shape[2:]).squeeze()
    scores = 1 - scores / np.linalg.norm(scores)
    all_images = np.array(per_batch_images)
    all_images = all_images.reshape(-1, *all_images.shape[2:])
    return scores, all_images
|
||||
|
||||
def project(
    G,
    target_image: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    target_text,
    *,
    num_steps = 300,
    w_avg_samples = 8192,
    initial_learning_rate = 0.02,
    initial_latent = None,
    initial_noise_factor = 0.01,
    lr_rampdown_length = 0.10,
    lr_rampup_length = 0.5,
    noise_ramp_length = 0.75,
    latent_range = 2.0,
    max_noise = 0.5,
    min_threshold = 0.6,
    use_vgg = True,
    use_clip = True,
    use_pixel = True,
    use_penalty = True,
    use_center = True,
    regularize_noise_weight = 1e5,
    kmeans = True,
    kmeans_clusters = 64,
    verbose = False,
    device: torch.device
):
    """Optimise a W latent of `G` towards a target image and/or a text prompt.

    The latent starts at `initial_latent` if given, otherwise at the mean of
    `w_avg_samples` mapped random z vectors, or -- when `kmeans`, `use_clip`
    and `target_text` are all set -- at the median of the four k-means cluster
    centres that score_images() ranks best for the prompt. It is then updated
    for `num_steps` AdamW steps against the sum of the enabled loss terms.

    Args:
        G: generator network; a frozen deep copy is used internally.
        target_image: image to match, or None for text-only projection
            (`use_vgg` and `use_pixel` are then switched off).
        target_text: prompt handed to the CLIP model as-is, or None.
            NOTE(review): presumably already tokenised by the caller, as
            run_projection() does -- confirm for other callers.
        num_steps: number of optimisation steps.
        w_avg_samples: number of random z vectors used for the W statistics
            (and as k-means input).
        initial_learning_rate: peak learning rate of the schedule.
        initial_latent: optional starting latent replacing the sampled mean.
        initial_noise_factor, max_noise, noise_ramp_length: scale and decay of
            the random noise added to the latent before each synthesis.
        lr_rampdown_length, lr_rampup_length: fractions of the run used to
            ramp the learning rate down / up.
        latent_range: latent values are clamped to [-latent_range, latent_range].
        min_threshold: threshold subtracted from the loss terms inside a ReLU.
        use_vgg, use_clip, use_pixel, use_penalty: toggle individual loss terms.
        use_center: additionally compare a centre crop of a 448x448 resize.
        regularize_noise_weight: weight of the noise-buffer regularisation.
        kmeans, kmeans_clusters: enable / size the k-means initialisation.
        verbose: print progress messages.
        device: torch device everything runs on.

    Returns:
        Tensor of shape [num_steps, *w_opt.shape[1:]] holding the clamped
        latent after every optimisation step.
    """
    if target_image is not None:
        assert target_image.shape == (G.img_channels, G.img_resolution, G.img_resolution)
    else:
        # Without a target image there is nothing for the image-based losses to compare against.
        use_vgg = False
        use_pixel = False

    # reduce errors unless using clip
    if use_clip:
        import clip

    def logprint(*args):
        # Print only when verbose output was requested.
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    labels = None
    if (G.mapping.c_dim):
        # Conditional network: draw random label vectors with the same fixed seed.
        labels = torch.from_numpy(0.5*np.random.RandomState(123).randn(w_avg_samples, G.mapping.c_dim)).to(device)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), labels) # [N, L, C]
    w_samples = w_samples.cpu().numpy().astype(np.float32) # [N, L, C]
    # Only the first of the L layer slots is kept as k-means input.
    w_samples_1d = w_samples[:, :1, :].astype(np.float32)

    w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, L, C]
    # Scalar spread of the samples around their mean; scales the latent noise below.
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    kmeans_latents = None
    if initial_latent is not None:
        w_avg = initial_latent
    else:
        if kmeans and use_clip and target_text is not None:
            # NOTE: this import rebinds the local name `kmeans` (so far the
            # bool parameter) to the clustering function.
            from kmeans_pytorch import kmeans
            # data
            data_size, dims, num_clusters = w_avg_samples, G.z_dim, kmeans_clusters
            x = w_samples_1d
            x = torch.from_numpy(x)

            # kmeans
            logprint(f'Performing kmeans clustering using {w_avg_samples} latents into {kmeans_clusters} clusters...')
            cluster_ids_x, cluster_centers = kmeans(
                X=x, num_clusters=num_clusters, distance='euclidean', device=device
            )
            #logprint(f'\nGenerating images from kmeans latents...')
            kmeans_latents = torch.tensor(cluster_centers, dtype=torch.float32, device=device, requires_grad=True)

    # Setup noise inputs.
    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }

    # Load VGG16 feature detector.
    if use_vgg:
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
        with dnnlib.util.open_url(url) as f:
            vgg16 = torch.jit.load(f).eval().to(device)

    # Load CLIP
    if use_clip:
        model, transform = clip.load("ViT-B/32", device=device)

    # Features for target image.
    if target_image is not None:
        target_images = target_image.unsqueeze(0).to(device).to(torch.float32)
        # 64x64 thumbnail for the low-resolution pixel loss.
        small_target = F.interpolate(target_images, size=(64, 64), mode='area')
        if use_center:
            # Middle 224x224 window of a 448x448 resize.
            center_target = F.interpolate(target_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
        target_images = target_images[:, :, 16:240, 16:240] # 256 -> 224, center crop

    if use_vgg:
        vgg_target_features = vgg16(target_images, resize_images=False, return_lpips=True)
        if use_center:
            vgg_target_center = vgg16(center_target, resize_images=False, return_lpips=True)

    if use_clip:
        if target_image is not None:
            with torch.no_grad():
                # Targets are in [0, 255]; rescale to [0, 1] and normalise before encoding.
                clip_target_features = model.encode_image(((target_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()
                if use_center:
                    clip_target_center = model.encode_image(((center_target / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]).float()

    if kmeans_latents is not None and use_clip and target_text is not None:
        # Rank the cluster centres against the prompt and start from the
        # median of the four with the smallest score.
        scores, kmeans_images = score_images(G, model, target_text, kmeans_latents.repeat([1, G.mapping.num_ws, 1]), device=device)
        ind = np.argpartition(scores, 4)[:4]
        w_avg = torch.median(kmeans_latents[ind],dim=0,keepdim=True)[0].repeat([1, G.mapping.num_ws, 1])

    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    # Copy of the starting point; the L1 penalty measures the distance from it.
    w_avg_tensor = w_opt.clone()
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    # The latent and all noise buffers are optimised jointly.
    optimizer = torch.optim.AdamW([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = max_noise * w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = w_opt + w_noise
        synth_images = G.synthesis(torch.clamp(ws,-latent_range,latent_range), noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. CLIP was built for 224x224 images.
        synth_images = (torch.clamp(synth_images, -1, 1) + 1) * (255/2)
        small_synth = F.interpolate(synth_images, size=(64, 64), mode='area')
        if use_center:
            center_synth = F.interpolate(synth_images, size=(448, 448), mode='area')[:, :, 112:336, 112:336]
        synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_images = synth_images[:, :, 16:240, 16:240] # 256 -> 224, center crop

        # Accumulates every enabled loss term for this step.
        dist = 0

        if use_vgg:
            vgg_synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
            vgg_dist = (vgg_target_features - vgg_synth_features).square().sum()
            if use_center:
                vgg_synth_center = vgg16(center_synth, resize_images=False, return_lpips=True)
                vgg_dist += (vgg_target_center - vgg_synth_center).square().sum()
            vgg_dist *= 6
            dist += F.relu(vgg_dist*vgg_dist - min_threshold)

        if use_clip:
            clip_synth_image = ((synth_images / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
            clip_synth_features = model.encode_image(clip_synth_image).float()
            # Weight of the CLIP terms: 2.0 with a single view, 1.0 once the centre crop is added.
            adj_center = 2.0

            if use_center:
                clip_cynth_center_image = ((center_synth / 255.0) - image_mean[None, :, None, None]) / image_std[None, :, None, None]
                adj_center = 1.0
                clip_synth_center = model.encode_image(clip_cynth_center_image).float()

            if target_image is not None:
                clip_dist = (clip_target_features - clip_synth_features).square().sum()
                if use_center:
                    clip_dist += (clip_target_center - clip_synth_center).square().sum()
                dist += F.relu(0.5 + adj_center*clip_dist - min_threshold)

            if target_text is not None:
                # model(...)[0] is the image/text score; a higher score lowers the loss.
                clip_text = 1 - model(clip_synth_image, target_text)[0].sum() / 100
                if use_center:
                    clip_text += 1 - model(clip_cynth_center_image, target_text)[0].sum() / 100
                dist += 2*F.relu(adj_center*clip_text*clip_text - min_threshold / adj_center)

        if use_pixel:
            pixel_dist = (target_images - synth_images).abs().sum() / 2000000.0
            if use_center:
                pixel_dist += (center_target - center_synth).abs().sum() / 2000000.0
            pixel_dist += (small_target - small_synth).square().sum() / 800000.0
            pixel_dist /= 4
            dist += F.relu(lr_ramp * pixel_dist - min_threshold)

        if use_penalty:
            # L1 distance of the latent from its starting point.
            l1_penalty = (w_opt - w_avg_tensor).abs().sum() / 5000.0
            dist += F.relu(lr_ramp * l1_penalty - min_threshold)

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                # Penalise correlation between neighbouring pixels, at every
                # scale down to 8 pixels.
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        #print(vgg_dist, clip_dist, pixel_dist, l1_penalty, reg_loss * regularize_noise_weight)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
        with torch.no_grad():
            # Keep the optimised latent itself inside the allowed range (in place).
            torch.clamp(w_opt,-latent_range,latent_range,out=w_opt)
        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]
        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target-image', 'target_fname', help='Target image file to project to', required=False, metavar='FILE', default=None)
@click.option('--target-text', help='Target text to project to', required=False, default=None)
@click.option('--initial-latent', help='Initial latent', default=None)
@click.option('--lr', help='Learning rate', type=float, default=0.1, show_default=True)
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
@click.option('--use-vgg', help='Use VGG16 in the loss', type=bool, default=True, show_default=True)
@click.option('--use-clip', help='Use CLIP in the loss', type=bool, default=True, show_default=True)
@click.option('--use-pixel', help='Use L1/L2 distance on pixels in the loss', type=bool, default=True, show_default=True)
@click.option('--use-penalty', help='Use a penalty on latent values distance from the mean in the loss', type=bool, default=True, show_default=True)
@click.option('--use-center', help='Optimize against an additional center image crop', type=bool, default=True, show_default=True)
@click.option('--use-kmeans', help='Perform kmeans clustering for selecting initial latents', type=bool, default=True, show_default=True)
def run_projection(
    network_pkl: str,
    target_fname: str,
    target_text: str,
    initial_latent: str,
    outdir: str,
    save_video: bool,
    seed: int,
    lr: float,
    num_steps: int,
    use_vgg: bool,
    use_clip: bool,
    use_pixel: bool,
    use_penalty: bool,
    use_center: bool,
    use_kmeans: bool,
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target-image=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    # NOTE: the docstring above doubles as the click --help text, so developer
    # notes live in comments instead.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image: sharpen, centre-crop to a square, resize to the
    # generator resolution and convert to a [C,H,W] uint8 tensor.
    target_image = None
    if target_fname:
        target_pil = PIL.Image.open(target_fname).convert('RGB').filter(ImageFilter.SHARPEN)

        w, h = target_pil.size
        s = min(w, h)
        target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
        target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
        target_uint8 = np.array(target_pil, dtype=np.uint8)
        target_image = torch.tensor(target_uint8.transpose([2, 0, 1]), device=device)

    if target_text:
        # `clip` is imported only inside project(); without this local import
        # the tokenize call below raised NameError.
        import clip
        target_text = torch.cat([clip.tokenize(target_text)]).to(device)
    else:
        # project() tests `target_text is not None`, so an empty prompt must
        # be passed on as "no prompt".
        target_text = None

    if initial_latent is not None:
        # Use the first array stored in the .npz archive and close the file.
        with np.load(initial_latent) as latent_file:
            initial_latent = latent_file[latent_file.files[0]]

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target_image=target_image,
        target_text=target_text,
        initial_latent=initial_latent,
        initial_learning_rate=lr,
        num_steps=num_steps,
        use_vgg=use_vgg,
        use_clip=use_clip,
        use_pixel=use_pixel,
        use_penalty=use_penalty,
        use_center=use_center,
        kmeans=use_kmeans,
        device=device,
        verbose=True
    )
    print (f'Elapsed: {(perf_counter()-start_time):.1f} s')

    os.makedirs(outdir, exist_ok=True)
    # Save final projected frame and W vector.
    if target_fname:
        target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

    # Render debug output: optional video and projected image and W vector.
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print (f'Saving optimization progress video "{outdir}/proj.mp4"')
        try:
            for step_w in projected_w_steps:
                frame = G.synthesis(step_w.unsqueeze(0), noise_mode='const')
                frame = (frame + 1) * (255/2)
                frame = frame.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
                if target_fname:
                    # Target on the left, current projection on the right.
                    video.append_data(np.concatenate([target_uint8, frame], axis=1))
                else:
                    video.append_data(frame)
        finally:
            # Always finalise the file, even if rendering a frame fails.
            video.close()
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Script entry point: run_projection is a click command, so its arguments are
# taken from the command line rather than passed here.
if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter
|
||||
|
||||
#----------------------------------------------------------------------------
|
|
@ -0,0 +1,216 @@
|
|||
import os
|
||||
import datetime
|
||||
import json
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
from enum import Enum
|
||||
import logging
|
||||
import numpy as np
|
||||
import dnnlib
|
||||
import legacy
|
||||
|
||||
logger = logging.getLogger('runs')
|
||||
|
||||
def jsonlines(filename):
    """Lazily decode a JSON Lines file, yielding one object per line."""
    with open(filename, 'r') as handle:
        # A text file iterates line by line; each line is one JSON document.
        yield from map(json.loads, handle)
|
||||
|
||||
class Snapshot():
    """One network snapshot of a training run, described by a metrics record.

    `metrics` is one decoded line of the run's metrics file and must provide
    the keys `snapshot_pkl`, `timestamp` and `results`.
    """

    # Snapshot pickles are named network-snapshot-<iteration>.pkl
    _PKL_PREFIX = "network-snapshot-"
    _PKL_SUFFIX = ".pkl"

    def __init__(self, run, metrics):
        self.run = run
        self.metrics = metrics
        # Keep the zero-padded text too: file names are rebuilt from it.
        self.iteration_str = metrics["snapshot_pkl"][len(self._PKL_PREFIX):-len(self._PKL_SUFFIX)]
        self.iteration = int(self.iteration_str)

    @property
    def id(self):
        """Identifier combining the run number and the padded iteration."""
        return f"{self.run.as_nr}_{self.iteration_str}"

    @property
    def fid(self):
        """Fréchet inception distance, as calculated during training"""
        return self.metrics['results']['fid50k_full']

    @property
    def cumulative_iteration(self):
        """Iteration nr, taking into account the snapshot the run.resumed_from"""
        if self.run.resumed_from is None:
            return self.iteration
        return self.run.resumed_from.iteration + self.iteration

    @property
    def time(self):
        """Local datetime built from the record's `timestamp` (whole seconds)."""
        return datetime.datetime.fromtimestamp(int(self.metrics['timestamp']))

    @property
    def pkl_path(self):
        """Path of the snapshot pickle inside the run directory."""
        return os.path.join(self.run.directory, f"network-snapshot-{self.iteration_str}.pkl")

    def load_generator(self, device):
        """Load the `G_ema` network from the snapshot pickle onto `device`."""
        with dnnlib.util.open_url(self.pkl_path) as f:
            return legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    # The annotation is a string so it names the Image.Image class (not the
    # PIL.Image module) without being evaluated at definition time.
    def get_preview_img(self, cols = 1, rows = 1) -> "Image.Image":
        """Return the top-left `cols` x `rows` tiles of this snapshot's fakes image.

        Each tile is `run.resolution` pixels square.
        """
        file = os.path.join(self.run.directory, f"fakes{self.iteration_str}.png")
        tile = self.run.resolution
        img = Image.open(file)
        return img.crop((0, 0, tile * cols, tile * rows))
|
||||
|
||||
class Run():
    """A training run directory: its options, metrics and recorded snapshots.

    Attributes set by the constructor:
        directory / id: the run directory and its base name.
        training_options: decoded `training_options.json`.
        resumed_from: Snapshot of the parent run this run was resumed from,
            or None when there is none or it could not be located.
        snapshots: one Snapshot per line of `metric-fid50k_full.jsonl`
            (empty list when that file does not exist).
    """

    def __init__(self, directory):
        self.directory = directory
        self.id = os.path.basename(directory)

        self.metric_path = os.path.join(self.directory, 'metric-fid50k_full.jsonl')
        self.options_path = os.path.join(self.directory, 'training_options.json')
        self.stats_path = os.path.join(self.directory, 'stats.jsonl')

        with open (self.options_path) as fp:
            self.training_options = json.load(fp)

        self.resumed_from = self._find_parent_snapshot()

        if os.path.exists(self.metric_path):
            self.snapshots = [Snapshot(self, l) for l in jsonlines(self.metric_path)]
        else:
            self.snapshots = []

    def _find_parent_snapshot(self):
        """Return the Snapshot named by the `resume_pkl` option, or None."""
        resume_pkl = self.training_options.get('resume_pkl')
        if not resume_pkl:
            return None
        wanted = os.path.abspath(resume_pkl)
        try:
            parent = Run(os.path.dirname(resume_pkl))
            return next(s for s in parent.snapshots if os.path.abspath(s.pkl_path) == wanted)
        except Exception:
            # Best effort: any failure (unreadable parent directory, no
            # matching snapshot) only leaves the run without a parent.
            logger.warning("Could not load parent snapshot")
            logger.debug("Parent snapshot lookup failed for %s", resume_pkl, exc_info=True)
            return None

    @property
    def as_nr(self):
        """First five characters of the run id (the run number prefix)."""
        return self.id[:5]

    @property
    def duration(self):
        """Time between the first and last snapshot.

        Raises IndexError for a run without snapshots.
        """
        return self.snapshots[-1].time - self.snapshots[0].time

    @property
    def kimg_offset(self):
        """Iteration of the parent snapshot, or 0 when not resumed."""
        if not self.resumed_from:
            return 0
        return self.resumed_from.iteration

    def get_stats(self):
        """fetch stats from stats.jsonl file
        Each stats has `num` (nr. of datapoints),
        `mean` (mean of points), `std` (std dev)
        yields each line
        """
        yield from jsonlines(self.stats_path)

    def is_empty(self):
        """Return True when the run has no recorded snapshots."""
        return len(self.snapshots) < 1

    @property
    def dataset_id(self):
        """Last non-empty component of the training set path."""
        return list(filter(None, self.training_options["training_set_kwargs"]["path"].split(os.path.sep)))[-1]

    def dataset_is_conditional(self):
        """Return whether the training set was configured with `use_labels`."""
        return bool(self.training_options["training_set_kwargs"]["use_labels"])

    @property
    def resolution(self):
        """Training set resolution from the training options."""
        return self.training_options["training_set_kwargs"]["resolution"]

    @property
    def r1_gamma(self):
        """`r1_gamma` value from the loss options."""
        return self.training_options["loss_kwargs"]["r1_gamma"]

    def get_summary(self):
        """Return a flat dict describing the run and its latest snapshot.

        Raises IndexError for a run without snapshots.
        """
        return {
            # "name": self.id,
            "nr": self.as_nr,
            "dataset": self.dataset_id,
            "conditional": self.dataset_is_conditional(),
            "resolution": self.resolution,
            "gamma": self.r1_gamma,
            "duration": self.duration,
            # "finished": self.snapshots[-1].time,
            "iterations": self.snapshots[-1].iteration,
            "last_fid": self.snapshots[-1].fid
        }
|
||||
|
||||
def get_runs_in_dir(dir_path, include_empty = False) -> List[Run]:
    """Load every entry of `dir_path` as a Run, in sorted name order.

    Runs without snapshots are left out unless `include_empty` is true.
    """
    loaded = (Run(os.path.join(dir_path, entry)) for entry in sorted(os.listdir(dir_path)))
    return [candidate for candidate in loaded if include_empty or not candidate.is_empty()]
|
||||
|
||||
class StreetType(Enum):
    """Street categories recognised in projection directory names."""
    RUE = 'Rue'
    AVENUE = 'Avenue'
    BOULEVARD = 'Boulevard'


class Projection():
    """A projection output directory holding proj.png, target.png and projected_w.npz.

    The directory name is expected to look like
    `<arrondissement>-<street type>...`, see from_path().
    """
    # TODO: add snapshot and dataset
    def __init__(self, path, identifier, arrondisement: int, street_type: StreetType):
        self.path = path
        self.id = identifier
        self.arrondisement = arrondisement
        self.street_type = street_type

    @property
    def img_path(self):
        """Path of the projected image."""
        return os.path.join(self.path, 'proj.png')

    @property
    def target_img_path(self):
        """Path of the target image the projection was fitted to."""
        return os.path.join(self.path, 'target.png')

    @property
    def w_path(self):
        """Path of the archive holding the projected latent."""
        return os.path.join(self.path, 'projected_w.npz')

    def load_w(self):
        """Return the array stored under key `w` in projected_w.npz."""
        with np.load(self.w_path) as data:
            return data['w']

    @classmethod
    def from_path(cls, path):
        """Build a Projection from its directory path.

        The last path component becomes the id; its first `-`-separated part
        is parsed as the arrondissement and the second must start with one of
        the StreetType values.

        Raises:
            ValueError: if the arrondissement is not an integer or no street
                type can be determined from the directory name.
        """
        # normpath drops trailing separators, so basename is the directory
        # itself on every platform.
        dirname = os.path.basename(os.path.normpath(path))
        parts = dirname.split('-')
        arrondisement = int(parts[0])
        # A name without '-' has no street part; report it like an unknown type.
        street_part = parts[1] if len(parts) > 1 else ''
        street_type = next((t for t in StreetType if street_part.startswith(t.value)), None)
        if street_type is None:
            raise ValueError(f"Unable to determine street type for {path}")

        return cls(path, dirname, arrondisement, street_type)
|
||||
|
||||
|
||||
|
||||
def get_projections_in_dir(projection_folder) -> List[Projection]:
    """Return a Projection for every sub-directory containing projected_w.npz."""
    found = []
    for entry in os.listdir(projection_folder):
        candidate = os.path.join(projection_folder, entry)
        # Only directories that actually hold a projected latent count.
        if os.path.exists(os.path.join(candidate, "projected_w.npz")):
            found.append(candidate)
    return [Projection.from_path(candidate) for candidate in found]
|
|
@ -0,0 +1,82 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Document</title>
|
||||
|
||||
<script src="https://unpkg.com/pagedjs/dist/paged.polyfill.js"></script>
|
||||
<link rel="stylesheet" href="style.css">
|
||||
<link href="pagedjs-interface.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
<section id='cover'>
|
||||
<h1 class="title">This Place Does Exist</h1>
|
||||
<h2>Stylegan 3 Snapshots</h2>
|
||||
</section>
|
||||
|
||||
<section id="toc">
|
||||
<h2>Table of contents</h2>
|
||||
<img src="{{runs_graph}}" id='runs-graph'>
|
||||
|
||||
{{runs_table}}
|
||||
|
||||
|
||||
<img src="{{runs_losses_graph}}" id='runs-losses-graph'>
|
||||
</section>
|
||||
|
||||
{% for run in runs %}
|
||||
<section class="run" id="run{{run.as_nr}}">
|
||||
|
||||
<h1>{{run.as_nr}}</h1>
|
||||
<h2>{{run.id}}</h2>
|
||||
|
||||
<dl class="metadata">
|
||||
<dt>dataset</dt>
|
||||
<dd>{{run.dataset_id}}</dd>
|
||||
<dt>conditional_dataset</dt>
|
||||
<dd>{{run.dataset_is_conditional()}}</dd>
|
||||
<dt>resolution</dt>
|
||||
<dd>{{run.resolution}}</dd>
|
||||
<dt>r1 gamma</dt>
|
||||
<dd>{{run.r1_gamma}}</dd>
|
||||
<dt>duration</dt>
|
||||
<dd>{{run.duration}}</dd>
|
||||
<dt>latest_snapshots</dt>
|
||||
<dd>{{run.snapshots[-1].iteration}}</dd>
|
||||
<dt>finished at</dt>
|
||||
<dd>{{run.snapshots[-1].time}}</dd>
|
||||
<dt>last_fid</dt>
|
||||
<dd>{{run.snapshots[-1].fid}}</dd>
|
||||
{% if run.resumed_from %}
|
||||
|
||||
<dt>resumed_from</dt>
|
||||
<dd>{{run.resumed_from.run.id}} {{run.resumed_from.iteration_str}}</dd>
|
||||
|
||||
{% endif %}
|
||||
|
||||
</dl>
|
||||
|
||||
|
||||
<div class="snapshots">
|
||||
|
||||
{% for snapshot in run.snapshots %}
|
||||
<div class="snapshot">
|
||||
<span class="iteration">{{snapshot.iteration}}</span>
|
||||
<span class="time">{{snapshot.time}}</span>
|
||||
<img src="imgs/{{snapshot.id}}.jpg">
|
||||
</div>
|
||||
{% endfor %}
|
||||
|
||||
</div>
|
||||
|
||||
</section>
|
||||
{% endfor %}
|
||||
|
||||
</body>
|
||||
|
||||
</html>
|
|
@ -0,0 +1,101 @@
|
|||
:root{
|
||||
font-family: Lexend, sans-serif;
|
||||
font-size:10pt;
|
||||
}
|
||||
|
||||
@page {
|
||||
size: A4;
|
||||
margin-top: 10mm;
|
||||
margin-right: 20mm;
|
||||
margin-bottom: 25mm;
|
||||
margin-left: 15mm;
|
||||
|
||||
|
||||
@bottom-center {
|
||||
content: string(title, first-except);
|
||||
/* text-transform: uppercase; */
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@page :left{
|
||||
|
||||
@bottom-left {
|
||||
content: counter(page);
|
||||
}
|
||||
}
|
||||
|
||||
@page :right{
|
||||
|
||||
@bottom-right {
|
||||
content: counter(page);
|
||||
}
|
||||
}
|
||||
|
||||
@page :first{
|
||||
@bottom-right{
|
||||
content:none;
|
||||
}
|
||||
}
|
||||
|
||||
img{
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@media print {

    /* Mirrored side margins: 35mm on the outer edge, 15mm towards the spine.
       The page pseudo-class is written `@page :left`; `@page: left` is not
       valid and the whole rule was being dropped. */
    @page :left {
        margin-left: 35mm;
        margin-right: 15mm;
    }

    @page :right {
        margin-left: 15mm;
        margin-right: 35mm;
    }

    /* all your book chapters in <section> elements...
       you want your chapter to always start on the right page. */
    section {
        break-before: right;
    }

    /* Each h1 also feeds the running `title` string used in the page footer. */
    h1 {
        string-set: title content(text);
        font-size: 8em;
        margin-bottom: 20px;
        font-family: "Lexend Zetta";
    }

    /* The cover section is marked up with id="cover", not a class. */
    #cover h1 {
        font-size: 4em;
    }

    h2 {
        font-weight: normal;
        font-size: 1.5em;
    }

    /* Append the page number of the element referenced by data-ref. */
    .tocitem::after {
        content: target-counter(attr(data-ref), page);
        margin-left: 10px;
    }

    /* Keep a snapshot's caption and image on the same page. */
    .snapshot {
        page-break-inside: avoid;
    }

    .snapshot .iteration {
        float: right;
    }
}
|
||||
|
||||
dt{
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
table {
|
||||
margin: 40px 0;
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
/* CSS for Paged.js interface – v0.4 */
|
||||
|
||||
/* Change the look */
|
||||
:root {
|
||||
--color-background: whitesmoke;
|
||||
--color-pageSheet: #cfcfcf;
|
||||
--color-pageBox: violet;
|
||||
--color-paper: white;
|
||||
--color-marginBox: transparent;
|
||||
--pagedjs-crop-color: black;
|
||||
--pagedjs-crop-shadow: white;
|
||||
--pagedjs-crop-stroke: 1px;
|
||||
}
|
||||
|
||||
/* To define how the book look on the screen: */
|
||||
@media screen, pagedjs-ignore {
|
||||
body {
|
||||
background-color: var(--color-background);
|
||||
}
|
||||
|
||||
.pagedjs_pages {
|
||||
display: flex;
|
||||
width: calc(var(--pagedjs-width) * 2);
|
||||
flex: 0;
|
||||
flex-wrap: wrap;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
.pagedjs_page {
|
||||
background-color: var(--color-paper);
|
||||
box-shadow: 0 0 0 1px var(--color-pageSheet);
|
||||
margin: 0;
|
||||
flex-shrink: 0;
|
||||
flex-grow: 0;
|
||||
margin-top: 10mm;
|
||||
}
|
||||
|
||||
.pagedjs_first_page {
|
||||
margin-left: var(--pagedjs-width);
|
||||
}
|
||||
|
||||
.pagedjs_page:last-of-type {
|
||||
margin-bottom: 10mm;
|
||||
}
|
||||
|
||||
.pagedjs_pagebox{
|
||||
box-shadow: 0 0 0 1px var(--color-pageBox);
|
||||
}
|
||||
|
||||
.pagedjs_left_page{
|
||||
z-index: 20;
|
||||
width: calc(var(--pagedjs-bleed-left) + var(--pagedjs-pagebox-width))!important;
|
||||
}
|
||||
|
||||
.pagedjs_left_page .pagedjs_bleed-right .pagedjs_marks-crop {
|
||||
border-color: transparent;
|
||||
}
|
||||
|
||||
.pagedjs_left_page .pagedjs_bleed-right .pagedjs_marks-middle{
|
||||
width: 0;
|
||||
}
|
||||
|
||||
.pagedjs_right_page{
|
||||
z-index: 10;
|
||||
position: relative;
|
||||
left: calc(var(--pagedjs-bleed-left)*-1);
|
||||
}
|
||||
|
||||
/* show the margin-box */
|
||||
|
||||
.pagedjs_margin-top-left-corner-holder,
|
||||
.pagedjs_margin-top,
|
||||
.pagedjs_margin-top-left,
|
||||
.pagedjs_margin-top-center,
|
||||
.pagedjs_margin-top-right,
|
||||
.pagedjs_margin-top-right-corner-holder,
|
||||
.pagedjs_margin-bottom-left-corner-holder,
|
||||
.pagedjs_margin-bottom,
|
||||
.pagedjs_margin-bottom-left,
|
||||
.pagedjs_margin-bottom-center,
|
||||
.pagedjs_margin-bottom-right,
|
||||
.pagedjs_margin-bottom-right-corner-holder,
|
||||
.pagedjs_margin-right,
|
||||
.pagedjs_margin-right-top,
|
||||
.pagedjs_margin-right-middle,
|
||||
.pagedjs_margin-right-bottom,
|
||||
.pagedjs_margin-left,
|
||||
.pagedjs_margin-left-top,
|
||||
.pagedjs_margin-left-middle,
|
||||
.pagedjs_margin-left-bottom {
|
||||
box-shadow: 0 0 0 1px inset var(--color-marginBox);
|
||||
}
|
||||
|
||||
/* uncomment this part for recto/verso book : ------------------------------------ */
|
||||
|
||||
/*
|
||||
.pagedjs_pages {
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.pagedjs_first_page {
|
||||
margin-left: 0;
|
||||
}
|
||||
|
||||
.pagedjs_page {
|
||||
margin: 0 auto;
|
||||
margin-top: 10mm;
|
||||
}
|
||||
|
||||
.pagedjs_left_page{
|
||||
width: calc(var(--pagedjs-bleed-left) + var(--pagedjs-pagebox-width) + var(--pagedjs-bleed-left))!important;
|
||||
}
|
||||
|
||||
.pagedjs_left_page .pagedjs_bleed-right .pagedjs_marks-crop{
|
||||
border-color: var(--pagedjs-crop-color);
|
||||
}
|
||||
|
||||
.pagedjs_left_page .pagedjs_bleed-right .pagedjs_marks-middle{
|
||||
width: var(--pagedjs-cross-size)!important;
|
||||
}
|
||||
|
||||
.pagedjs_right_page{
|
||||
left: 0;
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
|
||||
/*--------------------------------------------------------------------------------------*/
|
||||
|
||||
|
||||
|
||||
/* uncomment this par to see the baseline : -------------------------------------------*/
|
||||
|
||||
|
||||
/* .pagedjs_pagebox {
|
||||
--pagedjs-baseline: 22px;
|
||||
--pagedjs-baseline-position: 5px;
|
||||
--pagedjs-baseline-color: cyan;
|
||||
background: linear-gradient(transparent 0%, transparent calc(var(--pagedjs-baseline) - 1px), var(--pagedjs-baseline-color) calc(var(--pagedjs-baseline) - 1px), var(--pagedjs-baseline-color) var(--pagedjs-baseline)), transparent;
|
||||
background-size: 100% var(--pagedjs-baseline);
|
||||
background-repeat: repeat-y;
|
||||
background-position-y: var(--pagedjs-baseline-position);
|
||||
} */
|
||||
|
||||
|
||||
/*--------------------------------------------------------------------------------------*/
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
/* Marks (to delete when merge in paged.js) */
|
||||
|
||||
.pagedjs_marks-crop{
|
||||
z-index: 999999999999;
|
||||
|
||||
}
|
||||
|
||||
.pagedjs_bleed-top .pagedjs_marks-crop,
|
||||
.pagedjs_bleed-bottom .pagedjs_marks-crop{
|
||||
box-shadow: 1px 0px 0px 0px var(--pagedjs-crop-shadow);
|
||||
}
|
||||
|
||||
.pagedjs_bleed-top .pagedjs_marks-crop:last-child,
|
||||
.pagedjs_bleed-bottom .pagedjs_marks-crop:last-child{
|
||||
box-shadow: -1px 0px 0px 0px var(--pagedjs-crop-shadow);
|
||||
}
|
||||
|
||||
.pagedjs_bleed-left .pagedjs_marks-crop,
|
||||
.pagedjs_bleed-right .pagedjs_marks-crop{
|
||||
box-shadow: 0px 1px 0px 0px var(--pagedjs-crop-shadow);
|
||||
}
|
||||
|
||||
.pagedjs_bleed-left .pagedjs_marks-crop:last-child,
|
||||
.pagedjs_bleed-right .pagedjs_marks-crop:last-child{
|
||||
box-shadow: 0px -1px 0px 0px var(--pagedjs-crop-shadow);
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Document</title>
|
||||
|
||||
<script src="https://unpkg.com/pagedjs/dist/paged.polyfill.js"></script>
|
||||
<link rel="stylesheet" href="style.css">
|
||||
<link href="pagedjs-interface.css" rel="stylesheet" type="text/css" />
|
||||
</head>
|
||||
|
||||
<body>
|
||||
|
||||
<section id='cover'>
|
||||
<h1 class="title">This Place Does Exist</h1>
|
||||
<h2>Stylegan 3 Snapshots</h2>
|
||||
</section>
|
||||
|
||||
<section id="toc">
|
||||
<h2>Table of contents</h2>
|
||||
<img src="{{runs_graph}}" id='runs-graph'>
|
||||
|
||||
{{runs_table}}
|
||||
|
||||
|
||||
<img src="{{runs_losses_graph}}" id='runs-losses-graph'>
|
||||
</section>
|
||||
|
||||
{% for run in runs %}
|
||||
<section class="run" id="run{{run.as_nr}}">
|
||||
|
||||
<h1>{{run.as_nr}}</h1>
|
||||
<h2>{{run.id}}</h2>
|
||||
|
||||
<dl class="metadata">
|
||||
<dt>dataset</dt>
|
||||
<dd>{{run.dataset_id}}</dd>
|
||||
<dt>conditional_dataset</dt>
|
||||
<dd>{{run.dataset_is_conditional()}}</dd>
|
||||
<dt>resolution</dt>
|
||||
<dd>{{run.resolution}}</dd>
|
||||
<dt>r1 gamma</dt>
|
||||
<dd>{{run.r1_gamma}}</dd>
|
||||
<dt>duration</dt>
|
||||
<dd>{{run.duration}}</dd>
|
||||
<dt>latest_snapshots</dt>
|
||||
<dd>{{run.snapshots[-1].iteration}}</dd>
|
||||
<dt>finished at</dt>
|
||||
<dd>{{run.snapshots[-1].time}}</dd>
|
||||
<dt>last_fid</dt>
|
||||
<dd>{{run.snapshots[-1].fid}}</dd>
|
||||
{% if run.resumed_from %}
|
||||
|
||||
<dt>resumed_from</dt>
|
||||
<dd>{{run.resumed_from.run.id}} {{run.resumed_from.iteration_str}}</dd>
|
||||
|
||||
{% endif %}
|
||||
|
||||
</dl>
|
||||
|
||||
|
||||
<div class="snapshots">
|
||||
|
||||
{% for snapshot in run.snapshots %}
|
||||
<div class="snapshot">
|
||||
<span class="iteration">{{snapshot.iteration}}</span>
|
||||
<span class="time">{{snapshot.time}}</span>
|
||||
<img src="imgs/{{snapshot.id}}.jpg">
|
||||
</div>
|
||||
{% endfor %}
|
||||
|
||||
</div>
|
||||
|
||||
</section>
|
||||
{% endfor %}
|
||||
|
||||
</body>
|
||||
|
||||
</html>
|
|
@ -0,0 +1,101 @@
|
|||
:root{
|
||||
font-family: Lexend, sans-serif;
|
||||
font-size:10pt;
|
||||
}
|
||||
|
||||
@page {
|
||||
size: A4;
|
||||
margin-top: 10mm;
|
||||
margin-right: 20mm;
|
||||
margin-bottom: 25mm;
|
||||
margin-left: 15mm;
|
||||
|
||||
|
||||
@bottom-center {
|
||||
content: string(title, first-except);
|
||||
/* text-transform: uppercase; */
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@page :left{
|
||||
|
||||
@bottom-left {
|
||||
content: counter(page);
|
||||
}
|
||||
}
|
||||
|
||||
@page :right{
|
||||
|
||||
@bottom-right {
|
||||
content: counter(page);
|
||||
}
|
||||
}
|
||||
|
||||
@page :first{
|
||||
@bottom-right{
|
||||
content:none;
|
||||
}
|
||||
}
|
||||
|
||||
img{
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
@media print {
|
||||
|
||||
|
||||
@page :left {
|
||||
margin-left: 35mm;
|
||||
margin-right: 15mm;
|
||||
}
|
||||
|
||||
@page :right {
|
||||
margin-left: 15mm;
|
||||
margin-right: 35mm;
|
||||
}
|
||||
|
||||
/* all your book chapters in <section> elements...
|
||||
you want your chapter to always start on the right page. */
|
||||
section {
|
||||
break-before: right;
|
||||
}
|
||||
|
||||
h1 {
|
||||
string-set: title content(text);
|
||||
font-size: 8em;
|
||||
margin-bottom:20px;
|
||||
font-family: "Lexend Zetta";
|
||||
}
|
||||
#cover h1{
|
||||
font-size: 4em;
|
||||
}
|
||||
|
||||
h2{
|
||||
font-weight:normal;
|
||||
font-size:1.5em;
|
||||
}
|
||||
|
||||
|
||||
|
||||
.tocitem::after{
|
||||
content: target-counter(attr(data-ref), page) ;
|
||||
margin-left: 10px;
|
||||
}
|
||||
|
||||
.snapshot{
|
||||
page-break-inside: avoid;
|
||||
}
|
||||
|
||||
.snapshot .iteration{
|
||||
float:right;
|
||||
}
|
||||
}
|
||||
|
||||
dt{
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
table {
|
||||
margin: 40px 0;
|
||||
}
|
|
@ -49,6 +49,7 @@ class Dataset(torch.utils.data.Dataset):
|
|||
if xflip:
|
||||
self._raw_idx = np.tile(self._raw_idx, 2)
|
||||
self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
|
||||
# TODO: perform a similar trick for other augmentations, such as random crops.
|
||||
|
||||
def _get_raw_labels(self):
|
||||
if self._raw_labels is None:
|
||||
|
|
Loading…
Reference in New Issue