{
"cells": [
{
"cell_type": "markdown",
"id": "c08a9c8e-bcce-4b45-94ad-b422ad60bea9",
"metadata": {},
"source": [
"# This Place Does Exist - Utilities for StyleGAN3\n",
"\n",
"This notebook contains utility functions for working with the models created by StyleGAN3.\n",
"\n",
"## Usage\n",
"\n",
"Include it in any notebook using `%run ThisPlaceDoesExist.ipynb`, after which everything defined in this notebook becomes available, including a `runs` variable: a list containing all the `Run` objects."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f4ee99c4-9c28-4fe4-9408-e130a0d446d3",
"metadata": {},
"outputs": [],
"source": [
"from runs import Run, Snapshot, get_projections_in_dir, get_runs_in_dir\n",
"from scipy.ndimage import uniform_filter1d\n",
"import numpy as np  # used as `np` by the latent-space and video helpers below\n",
"import cv2\n",
"from PIL import Image\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d362cae7-c5c9-4127-b221-92c859aa1620",
"metadata": {},
"outputs": [],
"source": [
"def is_main():\n",
"    \"\"\"Return True if this notebook is being run directly, and False if it\n",
"    is being run by calling %run in another notebook.\n",
"    Works around a Jupyter bug: https://github.com/ipython/ipython/issues/10967\n",
"    \"\"\"\n",
"    try:\n",
"        __file__\n",
"        # __file__ has been defined, so this notebook is\n",
"        # being run in a parent notebook\n",
"        return False\n",
"\n",
"    except NameError:\n",
"        # __file__ has not been defined, so this notebook is\n",
"        # not being run in a parent notebook\n",
"        return True\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2544f23f-3b2c-4d29-95a6-031cecab1e08",
"metadata": {},
"outputs": [],
"source": [
"args = {\n",
" 'runs_dir': 'training-runs', \n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4c428611-8d75-4f9a-ae5a-7960e7b01470",
"metadata": {},
"outputs": [],
"source": [
"runs = get_runs_in_dir(args['runs_dir'])\n"
]
},
{
"cell_type": "markdown",
"id": "7f77f181-059a-4568-8088-1eb86a1a172a",
"metadata": {},
"source": [
"See [Snapshot_images.ipynb](Snapshot_images.ipynb) for examples of each run/snapshot."
]
},
{
"cell_type": "markdown",
"id": "c3aa8404-aeb0-4f63-b4b4-c278f0cf3766",
"metadata": {},
"source": [
"## Plot run metrics\n",
"\n",
"We can plot the progress of the metrics (FID) for each run. Some runs are a continuation of an earlier run. This could be done to give the network a pretrained starting point, but in this case it was mainly because training was stopped every now and then and restarted (which produces a 'new' run)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c180db45-9bf1-4f0a-ab55-62b476a5897c",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_runs(runs, dpi=300, palette=None):\n",
"    \"\"\"Plot the FID of every snapshot of each run on a log scale.\"\"\"\n",
"    plt.figure(dpi=dpi)\n",
"    plt.yscale('log')\n",
"    for i, run in enumerate(runs):\n",
"        x = [s.cumulative_iteration for s in run.snapshots]\n",
"        y = [s.fid for s in run.snapshots]\n",
"        # cycle through the palette if one is given\n",
"        c = palette[i%len(palette)] if palette else None\n",
"        plt.plot(x, y, label = f\"{i} {run.id}\", c=c)\n",
"\n",
"    return plt\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "22c74f41-65e7-461a-9094-f0b3d8738c82",
"metadata": {},
"outputs": [],
"source": [
"if is_main():\n",
"    plot = plot_runs(runs)\n",
"    plot.legend(bbox_to_anchor=(1,0), loc=\"lower left\")\n",
"    plot.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b28462dc-63a7-4630-8cf2-44eb2869661a",
"metadata": {},
"outputs": [],
"source": [
"def plot_stats(stat_ids, runs, dpi=300, palette=None):\n",
"    # squeeze=False keeps `axes` indexable even when only one stat is plotted\n",
"    fig2, axes = plt.subplots(nrows=1, ncols=len(stat_ids), figsize=(10*len(stat_ids), 9), dpi=dpi, squeeze=False)\n",
"    axes = axes[0]\n",
"\n",
"    for i, ax in enumerate(axes):\n",
"        ax.set_xlabel('kimg')\n",
"        ax.set_ylabel(stat_ids[i])\n",
"        ax.set_yscale('symlog', linthresh=1) # 0-1: linear, >1: log scale\n",
"\n",
"    for i, run in enumerate(runs):\n",
"        # one row per logged step: [kimg, [mean, std] of stat 0, [mean, std] of stat 1, ...]\n",
"        stats = [\n",
"            [\n",
"                s['Progress/kimg']['mean'] + run.kimg_offset\n",
"            ] + [[s[sid]['mean'], s[sid]['std']] for sid in stat_ids]\n",
"            for s in run.get_stats()\n",
"        ]\n",
"        x = [ s[0] for s in stats ]\n",
"\n",
"        c = palette[i%len(palette)] if palette else None\n",
"        # smooth slightly for better readability\n",
"        # (j indexes the stat; i remains the run index used in the label)\n",
"        for j, stat_id in enumerate(stat_ids):\n",
"            error = [s[j+1][1] for s in stats]\n",
"            y = uniform_filter1d([s[j+1][0] for s in stats], size=20)\n",
"            axes[j].plot(x, y, label = f\"{i} {run.id}\", c=c)\n",
"            # draw std dev:\n",
"            # axes[j].fill_between(x, y-error, y+error,\n",
"            #                      alpha=0.2,\n",
"            #                      antialiased=True)\n",
"\n",
"    return plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8bb5a2db-4e04-4402-953c-262a4a010c96",
"metadata": {},
"outputs": [],
"source": [
"if is_main():\n",
"    plot = plot_stats([\n",
"        'Loss/D/loss',\n",
"        'Loss/G/loss',\n",
"    ], runs)\n",
"    plot.legend()\n",
"    plot.show()"
]
},
{
"cell_type": "markdown",
"id": "9c4d937f-0035-4760-a7de-93080dfe5438",
"metadata": {},
"source": [
"# Excerpts\n",
"\n",
"Looking at the runs with the lowest FID scores gives an impression of the quality of the networks.\n",
"\n",
"What stands out most is that the network trained on cropped images (00014+16) produces the most colourful images, and so does not converge to a beige-grey like many of the other networks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "461480c4-d4fe-4fc8-a7fa-6e66d468595a",
"metadata": {},
"outputs": [],
"source": [
"if is_main():\n",
"    # compare an earlier snapshot of the last run with its two most recent ones\n",
"    display(\n",
"        runs[-1].snapshots[-10].iteration,\n",
"        runs[-1].snapshots[-10].get_preview_img(8,1),\n",
"        runs[-1].snapshots[-2].iteration,\n",
"        runs[-1].snapshots[-2].get_preview_img(8,1),\n",
"        runs[-1].snapshots[-1].iteration,\n",
"        runs[-1].snapshots[-1].get_preview_img(8,1),\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9782b289-101d-4bc4-b9e4-8d0838dec09b",
"metadata": {},
"outputs": [],
"source": [
"if is_main():\n",
"    # preview the latest snapshot of three different runs\n",
"    display(\n",
"        runs[3].snapshots[-1].get_preview_img(4,1),\n",
"        runs[2].snapshots[-1].get_preview_img(4,1),\n",
"        runs[5].snapshots[-1].get_preview_img(4,1)\n",
"    )"
]
},
{
"cell_type": "markdown",
"id": "22668f34-fa75-4d05-abeb-c4b41bec8093",
"metadata": {},
"source": [
"# StyleGAN3 functions"
]
},
{
"cell_type": "markdown",
"id": "de9c301f-5a20-4eb7-87c1-83c8e67508f0",
"metadata": {},
"source": [
"Helper functions for StyleGAN3 operations."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b31b0b4b-7fde-4a8b-812e-5545c68ce64d",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "markdown",
"id": "ab78ac4c-5e2d-456c-9afd-9639ad963b51",
"metadata": {},
"source": [
"These functions convert seeds to `z`-space vectors, map `z`-space vectors to `w`-space, and generate images from `w`-space vectors. Further down, generated images are wrapped in Jupyter widgets for visualisation."
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "bb0593f9-05df-47e7-9555-0af43951a49b",
"metadata": {},
"outputs": [],
"source": [
"# adapted from https://github.com/dvschultz/stylegan2-ada-pytorch/blob/9b6750b96dc9841816e8ac57b05f395d0f23c30d/generate.py\n",
"\n",
"def seeds_to_zs(G,seeds):\n",
"    \"\"\"Turn each seed into a reproducible random z vector of shape (1, G.z_dim).\"\"\"\n",
"    zs = []\n",
"    for seed in seeds:\n",
"        z = np.random.RandomState(seed).randn(1, G.z_dim)\n",
"        zs.append(z)\n",
"    return zs\n",
"\n",
"def zs_to_ws(G,device,label,truncation_psi,zs):\n",
"    \"\"\"Map z vectors to w space using the mapping network of G.\"\"\"\n",
"    ws = []\n",
"    for z in zs:\n",
"        z = torch.from_numpy(z).to(device)\n",
"        w = G.mapping(z, label, truncation_psi=truncation_psi, truncation_cutoff=8)\n",
"        ws.append(w)\n",
"    return ws\n",
"\n",
"def images(G,device,inputs,space,truncation_psi,label,noise_mode,start=None,stop=None):\n",
"    \"\"\"Generate image for z or w space image (deprecated)\"\"\"\n",
"    # if start and stop are given, step truncation_psi linearly across the inputs\n",
"    if(start is not None and stop is not None):\n",
"        tp = start\n",
"        tp_i = (stop-start)/len(inputs)\n",
"\n",
"    for idx, i in enumerate(inputs):\n",
"        # print('Generating image for frame %d/%d ...' % (idx, len(inputs)))\n",
"\n",
"        if (space=='z'):\n",
"            z = torch.from_numpy(i).to(device)\n",
"            if(start is not None and stop is not None):\n",
"                img = G(z, label, truncation_psi=tp, noise_mode=noise_mode)\n",
"                tp = tp+tp_i\n",
"            else:\n",
"                img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)\n",
"        else:\n",
"            # a 2D numpy w (e.g. loaded from an .npz) still needs a batch dimension\n",
"            if len(i.shape) == 2:\n",
"                i = torch.from_numpy(i).unsqueeze(0).to(device)\n",
"            img = G.synthesis(i, noise_mode=noise_mode, force_fp32=True)\n",
"        # map the output from [-1, 1] to uint8 and move channels last\n",
"        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
"        yield f\"{idx:04d}\", Image.fromarray(img[0].cpu().numpy(), 'RGB')\n",
"\n",
"def w_to_img(G, device, noise_mode, w):\n",
"    \"\"\"Generate a single PIL image from a w vector.\"\"\"\n",
"    img = G.synthesis(w, noise_mode=noise_mode, force_fp32=True)\n",
"    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
"    return Image.fromarray(img[0].cpu().numpy(), 'RGB')\n"
]
},
{
"cell_type": "markdown",
"id": "efafbb3a-7d2c-4b6a-a626-f06dd1b2b1ec",
"metadata": {},
"source": [
"## Interpolation of vectors; browsing latent space."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "014449b1-1293-49bb-90bf-a1949315183c",
"metadata": {},
"outputs": [],
"source": [
"# NOTE: noiseloop, circularloop and line_interpolate are not defined in this\n",
"# notebook and must be provided separately before those modes can be used.\n",
"def interpolate(G,device,projected_w,seeds,random_seed,space,truncation_psi,label,frames,noise_mode,outdir,interpolation,\n",
"                easing, diameter, start=None,stop=None):\n",
"    if(interpolation=='noiseloop' or interpolation=='circularloop'):\n",
"        if seeds is not None:\n",
"            print(f'Warning: interpolation type \"{interpolation}\" does not support set seeds.')\n",
"\n",
"        if(interpolation=='noiseloop'):\n",
"            points = noiseloop(frames, diameter, random_seed)\n",
"        elif(interpolation=='circularloop'):\n",
"            points = circularloop(frames, diameter, random_seed, seeds)\n",
"\n",
"    else:\n",
"        if projected_w is not None:\n",
"            points = np.load(projected_w)['w']\n",
"        else:\n",
"            # get zs from seeds\n",
"            points = seeds_to_zs(G,seeds)\n",
"            # convert to ws\n",
"            if(space=='w'):\n",
"                points = zs_to_ws(G,device,label,truncation_psi,points)\n",
"\n",
"        # get interpolation points\n",
"        if(interpolation=='linear'):\n",
"            points = line_interpolate(points,frames,easing)\n",
"        elif(interpolation=='slerp'):\n",
"            points = slerp_interpolate(points,frames)\n",
"\n",
"    # generate frames (images() takes no outdir argument)\n",
"    for idx, img in images(G,device,points,space,truncation_psi,label,noise_mode,start,stop):\n",
"        yield idx, img\n",
"\n",
"def lerp(t, v0, v1):\n",
"    # plain linear interpolation, used by slerp for (almost) collinear vectors\n",
"    return (1.0 - t) * v0 + t * v1\n",
"\n",
"# slightly modified version of\n",
"# https://github.com/PDillis/stylegan2-fun/blob/master/run_generator.py#L399\n",
"def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):\n",
"    '''\n",
"    Spherical linear interpolation\n",
"    Args:\n",
"        t (float/np.ndarray): Float value between 0.0 and 1.0\n",
"        v0 (np.ndarray): Starting vector\n",
"        v1 (np.ndarray): Final vector\n",
"        DOT_THRESHOLD (float): Threshold for considering the two vectors as\n",
"                               collinear. Not recommended to alter this.\n",
"    Returns:\n",
"        v2 (torch.Tensor): Interpolation vector between v0 and v1, on the GPU\n",
"    '''\n",
"    v0 = v0.cpu().detach().numpy() if hasattr(v0, 'cpu') else v0\n",
"    v1 = v1.cpu().detach().numpy() if hasattr(v1, 'cpu') else v1\n",
"    # Copy the vectors to reuse them later\n",
"    v0_copy = np.copy(v0)\n",
"    v1_copy = np.copy(v1)\n",
"    # Normalize the vectors to get the directions and angles\n",
"    v0 = v0 / np.linalg.norm(v0)\n",
"    v1 = v1 / np.linalg.norm(v1)\n",
"    # Dot product with the normalized vectors (can't use np.dot in W)\n",
"    dot = np.sum(v0 * v1)\n",
"    # If absolute value of dot product is almost 1, vectors are ~collinear, so use lerp\n",
"    if np.abs(dot) > DOT_THRESHOLD:\n",
"        return torch.from_numpy(lerp(t, v0_copy, v1_copy)).to(\"cuda\")\n",
"    # Calculate initial angle between v0 and v1\n",
"    theta_0 = np.arccos(dot)\n",
"    sin_theta_0 = np.sin(theta_0)\n",
"    # Angle at timestep t\n",
"    theta_t = theta_0 * t\n",
"    sin_theta_t = np.sin(theta_t)\n",
"    # Finish the slerp algorithm\n",
"    s0 = np.sin(theta_0 - theta_t) / sin_theta_0\n",
"    s1 = sin_theta_t / sin_theta_0\n",
"    v2 = s0 * v0_copy + s1 * v1_copy\n",
"    return torch.from_numpy(v2).to(\"cuda\")\n",
"\n",
"def slerp_interpolate(zs, steps):\n",
"    # `steps` frames per pair of consecutive vectors; the end vector of each pair is excluded\n",
"    out = []\n",
"    for i in range(len(zs)-1):\n",
"        for index in range(steps):\n",
"            fraction = index/float(steps)\n",
"            out.append(slerp(fraction,zs[i],zs[i+1]))\n",
"    return out\n"
]
},
{
"cell_type": "markdown",
"id": "60be2986-914d-4032-8554-cdd3ab52f827",
"metadata": {},
"source": [
"## Project an image to the latent space"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f802546c-6e74-440f-8d88-9e5a166b319d",
"metadata": {},
"outputs": [],
"source": [
"import subprocess, operator\n",
"\n",
"def project_img_to_z(snapshot: Snapshot, image_filename: str, steps=1000, replace_if_exists=False) -> dict:\n",
"    \"\"\"Project an image into the latent space of a snapshot by running\n",
"    pbaylies_projector.py, and return the paths of the resulting files.\"\"\"\n",
"    # strip the 4-character file extension\n",
"    image_name = image_filename[:-4]\n",
"    out_dir = f\"out/projections/{snapshot.id}/{image_name}\"\n",
"    # !!python pbaylies_projector.py --network $snapshot_pkl --outdir out/projections/$runnr-$imagenr --target-image $image_filename --use-clip=False\n",
"\n",
"    if replace_if_exists or not os.path.exists(f\"{out_dir}/proj.png\"):\n",
"        process = subprocess.Popen([\n",
"            \"python\", \"pbaylies_projector.py\",\n",
"            \"--network\" , snapshot.pkl_path,\n",
"            \"--outdir\", out_dir,\n",
"            \"--target-image\", image_filename,\n",
"            \"--use-clip\", \"False\",\n",
"            \"--num-steps\", str(steps),\n",
"            \"--save-video\", \"False\"\n",
"        ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)\n",
"        stdout, stderr = process.communicate()\n",
"        lines = stdout.split(\"\\n\")\n",
"        # from every progress line, take the last and third-to-last space-separated fields\n",
"        losses_and_distances = [operator.itemgetter(-1,-3)(line.split(\" \")) for line in lines if line.startswith(\"step\")]\n",
"        print(stderr)\n",
"        # fall back to (None, None) if the projector printed no progress lines\n",
"        loss, dist = losses_and_distances[-1] if losses_and_distances else (None, None)\n",
"    else:\n",
"        # TODO: get loss and dist from somewhere? (currently not using it much)\n",
"        loss, dist = (None, None)\n",
"\n",
"    return {\n",
"        \"img\": f\"{out_dir}/proj.png\",\n",
"        \"src_img\": f\"{out_dir}/target.png\",\n",
"        \"src\": image_filename,\n",
"        \"npz\": f\"{out_dir}/projected_w.npz\",\n",
"        \"loss\": loss,\n",
"        \"dist\": dist\n",
"    }\n"
]
},
{
"cell_type": "markdown",
"id": "6f10a7eb-5056-41da-9cfd-894a1c772e0d",
"metadata": {},
"source": [
"# Displaying Videos and Images"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "f7ccb1c5-13c3-4ff9-9030-125622ccf93a",
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import ipywidgets as widgets\n",
"\n",
"def img_to_widget(img):\n",
"    buff = io.BytesIO()\n",
"    img.save(buff, format='png')\n",
"    # size the widget to the image itself\n",
"    width, height = img.size\n",
"    return widgets.Image(value=buff.getvalue(), format='png', width=width, height=height)\n",
"\n",
"def video_to_widget(filename):\n",
"    with open(filename, 'rb') as fp:\n",
"        video = fp.read()\n",
"    return widgets.Video(value=video)\n",
"\n",
"def image_grid(imgs, cols=None, rows=None, margin = 10):\n",
"    # create image grid; if no size is given, put all images on the horizontal axis\n",
"    if cols is None or rows is None:\n",
"        cols = len(imgs)\n",
"        rows = 1\n",
"\n",
"    # every cell has the size of the first image plus the margin\n",
"    w, h = imgs[0].size\n",
"    w, h = w+margin, h+margin\n",
"    grid = Image.new('RGB', size=(cols*w-margin, rows*h-margin))\n",
"\n",
"    for i, img in enumerate(imgs):\n",
"        grid.paste(img, box=(i%cols*w, i//cols*h))\n",
"    return grid"
]
},
{
"cell_type": "markdown",
"id": "c988053f-4a6d-4f66-a40d-593bf5a35765",
"metadata": {},
"source": [
"# Postprocessing"
]
},
{
"cell_type": "markdown",
"id": "573284e5-1c27-415c-9876-4a4a7e6fc704",
"metadata": {},
"source": [
"Sometimes we like to generate videos from a series of ws"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "7fd5f974-3f44-4f92-b1b8-7564400229fa",
"metadata": {},
"outputs": [],
"source": [
"from imageio_ffmpeg import write_frames\n",
"\n",
"def generator_to_video(generator, out_filename, fps, frame_size, quality):\n",
"    # frame_size is (width, height)\n",
"    writer = write_frames(out_filename, frame_size, fps=fps, quality=quality)\n",
"    writer.send(None) # seed the generator\n",
"    # print(os.path.abspath(out_filename))\n",
"    for frame in generator:\n",
"        if isinstance(frame, Image.Image):\n",
"            # frames are written as RGB, so no RGB->BGR conversion is needed\n",
"            frame = np.array(frame)\n",
"        writer.send(frame)\n",
"\n",
"    writer.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "paris-stylegan3",
"language": "python",
"name": "paris-stylegan3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}