{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tDybPQiEFQuJ" }, "source": [ "In this notebook, we will show how to load pre-trained models and draw things with sketch-rnn" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "k0GqvYgB9JLC" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_quint16 = np.dtype([(\"quint16\", 
np.uint16, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", "/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" ] } ], "source": [ "# import the required libraries\n", "import numpy as np\n", "import time\n", "import random\n", "import pickle\n", "import codecs\n", "import collections\n", "import os\n", "import math\n", "import json\n", "import tensorflow as tf\n", "from six.moves import xrange" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "UI4ZC__4FQuL" }, "outputs": [], "source": [ "# libraries required for visualisation:\n", "from IPython.display import SVG, display\n", "import PIL\n", "from PIL import Image\n", "import matplotlib.pyplot as plt\n", "\n", "# set numpy output to something sensible\n", "np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "D7ObpAUh9jrk" }, "outputs": [], "source": [ "# !pip install -qU svgwrite" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "4xYY-TUd9aiD" }, "outputs": [], "source": [ "import svgwrite # conda install -c omnia svgwrite=1.1.6" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "NzPSD-XRFQuP", "outputId": "daa0dd33-6d59-4d15-f437-d8ec787c8884" }, "outputs": [], "source": [ "tf.logging.info(\"TensorFlow Version: %s\", tf.__version__)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "id": "LebxcF4p90OR" }, "outputs": [], "source": [ "# !pip install -q magenta" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "NkFS0E1zFQuU" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: Logging before flag parsing goes to stderr.\n", "W0825 15:58:50.188074 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", "\n", "/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9\n", " warnings.warn(msg)\n", "W0825 15:58:50.811670 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. 
Please use tf.io.TFRecordWriter instead.\n", "\n", "W0825 15:58:51.114169 140470926227264 lazy_loader.py:50] \n", "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", "For more information, please see:\n", " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", " * https://github.com/tensorflow/addons\n", " * https://github.com/tensorflow/io (for I/O related ops)\n", "If you depend on functionality not listed there, please file an issue.\n", "\n", "W0825 15:58:51.115457 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", "\n", "W0825 15:58:51.116125 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.\n", "\n" ] } ], "source": [ "# import our command line tools\n", "from magenta.models.sketch_rnn.sketch_rnn_train import *\n", "from magenta.models.sketch_rnn.model import *\n", "from magenta.models.sketch_rnn.utils import *\n", "from magenta.models.sketch_rnn.rnn import *" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "GBde4xkEFQuX" }, "outputs": [], "source": [ "# little function that displays vector images and saves them to .svg\n", "def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):\n", " tf.gfile.MakeDirs(os.path.dirname(svg_filename))\n", " min_x, max_x, min_y, max_y = get_bounds(data, factor)\n", " dims = (50 + max_x - min_x, 50 + max_y - min_y)\n", " dwg = svgwrite.Drawing(svg_filename, size=dims)\n", " dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))\n", " lift_pen = 1\n", " abs_x = 25 - min_x \n", " abs_y = 25 - min_y\n", " p = \"M%s,%s \" % (abs_x, abs_y)\n", " command = \"m\"\n", " for i in xrange(len(data)):\n", " if (lift_pen == 1):\n", " command = \"m\"\n", " elif (command != \"l\"):\n", " command = \"l\"\n", " else:\n", " command = \"\"\n", " x = float(data[i,0])/factor\n", " y = float(data[i,1])/factor\n", " lift_pen = data[i, 2]\n", " p += command+str(x)+\",\"+str(y)+\" \"\n", " the_color = \"black\"\n", " stroke_width = 1\n", " dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill(\"none\"))\n", " dwg.save()\n", " display(SVG(dwg.tostring()))\n", "\n", "# generate a 2D grid of many vector drawings\n", "def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):\n", " def get_start_and_end(x):\n", " x = np.array(x)\n", " x = x[:, 0:2]\n", " x_start = x[0]\n", " x_end = x.sum(axis=0)\n", " x = x.cumsum(axis=0)\n", " x_max = x.max(axis=0)\n", " x_min = x.min(axis=0)\n", " center_loc = (x_max+x_min)*0.5\n", " return x_start-center_loc, x_end\n", " x_pos = 0.0\n", " y_pos = 0.0\n", " result = [[x_pos, y_pos, 1]]\n", " for sample in s_list:\n", " s = sample[0]\n", " grid_loc = sample[1]\n", " grid_y = grid_loc[0]*grid_space+grid_space*0.5\n", " grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5\n", " start_loc, delta_pos = get_start_and_end(s)\n", "\n", " loc_x = start_loc[0]\n", " loc_y = start_loc[1]\n", " new_x_pos = grid_x+loc_x\n", " new_y_pos = grid_y+loc_y\n", " result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])\n", "\n", " result += s.tolist()\n", " result[-1][2] = 
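{ "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check of `draw_strokes`, we can render a hand-made array in stroke-3 format, where each row is (dx, dy, pen_lift) and the deltas get divided by `factor`. The square below is a made-up example rather than data from any model, and the `svg_filename` path is only an illustration." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# a hypothetical stroke-3 array: trace a small square, lifting the pen at the end\n", "square = np.array([[10, 0, 0], [0, 10, 0], [-10, 0, 0], [0, -10, 1]])\n", "draw_strokes(square, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/square.svg')" ] },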
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "if7-UyxzFQuY" }, "source": [ "Define the path of the model you want to load, and also the path of the dataset." ] },
{ "cell_type": "code", "execution_count": 139, "metadata": { "colab": {}, "colab_type": "code", "id": "Dipv1EbsFQuZ" }, "outputs": [], "source": [ "data_dir = 'datasets/naam4'\n", "model_dir = 'models/naam4'" ] },
{ "cell_type": "code", "execution_count": 140, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "eaSqI0fIFQub", "outputId": "06df45a6-cc86-4f50-802e-25ae185037f7" }, "outputs": [], "source": [ "# download_pretrained_models(models_root_dir=models_root_dir)" ] },
{ "cell_type": "code", "execution_count": 153, "metadata": { "colab": {}, "colab_type": "code", "id": "G4sRuxyn_1aO" }, "outputs": [], "source": [ "def load_env_compatible(data_dir, model_dir):\n", "  \"\"\"Loads environment for inference mode, used in jupyter notebook.\"\"\"\n", "  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py\n", "  # to work with the deprecated tf.HParams functionality\n", "  model_params = sketch_rnn_model.get_default_hparams()\n", "  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:\n", "    data = json.load(f)\n", "  # these flags are stored as 0/1 in the JSON config; coerce them back to booleans\n", "  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']\n", "  for fix in fix_list:\n", "    data[fix] = (data[fix] == 1)\n", "  model_params.parse_json(json.dumps(data))\n", "  return load_dataset(data_dir, model_params, inference_mode=True)\n", "\n", "def load_model_compatible(model_dir):\n", "  \"\"\"Loads model for inference mode, used in jupyter notebook.\"\"\"\n", "  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py\n", "  # to work with the deprecated tf.HParams functionality\n", "  model_params = sketch_rnn_model.get_default_hparams()\n", "  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:\n", "    data = json.load(f)\n", "  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']\n", "  for fix in fix_list:\n", "    data[fix] = (data[fix] == 1)\n", "  model_params.parse_json(json.dumps(data))\n", "\n", "  model_params.batch_size = 1 # only sample one at a time\n", "  eval_model_params = sketch_rnn_model.copy_hparams(model_params)\n", "  eval_model_params.use_input_dropout = 0\n", "  eval_model_params.use_recurrent_dropout = 0\n", "  eval_model_params.use_output_dropout = 0\n", "  eval_model_params.is_training = 0\n", "  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)\n", "  sample_model_params.max_seq_len = 1 # sample one point at a time\n", "  return [model_params, eval_model_params, sample_model_params]" ] },
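{ "cell_type": "markdown", "metadata": {}, "source": [ "Optionally, peek at what is actually inside `model_dir` before loading anything -- a small sanity check (assuming the directory already exists) that `model_config.json` and the checkpoint files are where we expect them." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# list the files in model_dir (config, checkpoint index, weights)\n", "print(tf.gfile.ListDirectory(model_dir))" ] },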
{ "cell_type": "code", "execution_count": 154, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 153 }, "colab_type": "code", "id": "9m-jSAb3FQuf", "outputId": "debc045d-d15a-4b30-f747-fa4bcbd069fd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0825 17:10:28.492715 140470926227264 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n", "I0825 17:10:28.666448 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from lijn.npz\n", "I0825 17:10:29.123052 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n", "I0825 17:10:29.189439 140470926227264 sketch_rnn_train.py:159] Dataset combined: 1083 (361/361/361), avg len 234\n", "I0825 17:10:29.190502 140470926227264 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "total images <= max_seq_len is 361\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "I0825 17:10:29.655039 140470926227264 sketch_rnn_train.py:209] normalizing_scale_factor 55.2581.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "total images <= max_seq_len is 361\n", "total images <= max_seq_len is 361\n" ] } ], "source": [ "[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)" ] },
{ "cell_type": "code", "execution_count": 155, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 479 }, "colab_type": "code", "id": "1pHS8TSgFQui", "outputId": "50b0e14d-ff0f-43bf-d996-90e9e6a1491e" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "I0825 17:10:34.129225 140470926227264 model.py:87] Model using gpu.\n", "I0825 17:10:34.139610 140470926227264 model.py:175] Input dropout mode = False.\n", "I0825 17:10:34.141004 140470926227264 model.py:176] Output dropout mode = False.\n", "I0825 17:10:34.143371 140470926227264 model.py:177] Recurrent dropout mode = False.\n", "I0825 17:10:42.551079 140470926227264 model.py:87] Model using gpu.\n", "I0825 17:10:42.553035 140470926227264 model.py:175] Input dropout mode = 0.\n", "I0825 17:10:42.554712 140470926227264 model.py:176] Output dropout mode = 0.\n", "I0825 17:10:42.556115 140470926227264 model.py:177] Recurrent dropout mode = 0.\n", "I0825 17:10:43.679191 140470926227264 model.py:87] Model using gpu.\n", "I0825 17:10:43.681183 140470926227264 model.py:175] Input dropout mode = 0.\n", "I0825 17:10:43.682615 140470926227264 model.py:176] Output dropout mode = 0.\n", "I0825 17:10:43.683944 140470926227264 model.py:177] Recurrent dropout mode = 0.\n" ] } ], "source": [ "# construct the sketch-rnn model here:\n", "reset_graph()\n", "model = Model(hps_model)\n", "eval_model = Model(eval_hps_model, reuse=True)\n", "sample_model = Model(sample_hps_model, reuse=True)" ] },
{ "cell_type": "code", "execution_count": 156, "metadata": { "colab": {}, "colab_type": "code", "id": "1gxYLPTQFQuk" }, "outputs": [], "source": [ "sess = tf.InteractiveSession()\n", "sess.run(tf.global_variables_initializer())" ] },
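{ "cell_type": "markdown", "metadata": {}, "source": [ "Next we restore the trained weights into the session. Instead of hard-coding a checkpoint file name, we take the latest checkpoint recorded in `model_dir` (via `tf.train.get_checkpoint_state`) and restore from its `model_checkpoint_path`." ] },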
{ "cell_type": "code", "execution_count": 199, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "bVlDyfN_FQum", "outputId": "fb41ce20-4c7f-4991-e9f6-559ea9b34a31" }, "outputs": [], "source": [ "# loads the weights from checkpoint into our model\n", "# load_checkpoint(sess, model_dir)\n", "saver = tf.train.Saver(tf.global_variables())\n", "ckpt = tf.train.get_checkpoint_state(model_dir)\n", "# the checkpoint name encodes the training step; print it as a quick check\n", "print(int(ckpt.model_checkpoint_path.split(\"-\")[-1])/100)\n", "# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\n", "saver.restore(sess, ckpt.model_checkpoint_path)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EOblwpFeFQuq" }, "source": [ "We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke."
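, "\n\nHere `encode` pads the input stroke sequence to the model's `max_seq_len` (614, which gives an input tensor of shape (1, 615, 5) once the initial dummy point is prepended) and returns the latent vector $z$ from the encoder; `decode` samples a sketch from the decoder, optionally conditioned on a given $z$, with `temperature` controlling the randomness of the sampling."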
] }, { "cell_type": "code", "execution_count": 159, "metadata": { "colab": {}, "colab_type": "code", "id": "tMFlV487FQur" }, "outputs": [], "source": [ "def encode(input_strokes):\n", "  # pad the input sequence to the model's max_seq_len (614 for this model)\n", "  strokes = to_big_strokes(input_strokes, 614).tolist()\n", "  strokes.insert(0, [0, 0, 1, 0, 0])\n", "  seq_len = [len(input_strokes)]\n", "  print(seq_len)\n", "  draw_strokes(to_normal_strokes(np.array(strokes)))\n", "  print(np.array([strokes]).shape)\n", "  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]" ] },
{ "cell_type": "code", "execution_count": 160, "metadata": { "colab": {}, "colab_type": "code", "id": "1D5CV7ZlFQut" }, "outputs": [], "source": [ "def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2):\n", "  z = None\n", "  if z_input is not None:\n", "    z = [z_input]\n", "  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)\n", "  strokes = to_normal_strokes(sample_strokes)\n", "  if draw_mode:\n", "    draw_strokes(strokes, factor)\n", "  return strokes" ] },
{ "cell_type": "code", "execution_count": 201, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 123 }, "colab_type": "code", "id": "fUOAvRQtFQuw", "outputId": "c8e9a1c3-28db-4263-ac67-62ffece1e1e0" }, "outputs": [], "source": [ "# get a sample drawing from the test set, and render it to .svg\n", "stroke = test_set.random_sample()\n", "stroke = test_set.strokes[252] # override the random sample with a fixed example\n", "draw_strokes(stroke)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "j114Re2JFQuz" }, "source": [ "Let's try to encode the sample stroke into latent vector $z$" ] },
{ "cell_type": "code", "execution_count": 190, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 123 }, "colab_type": "code", "id": "DBRjPBo-FQu0", "outputId": "e089dc78-88e3-44c6-ed7e-f1844471f47f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[244]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1, 615, 5)\n" ] } ], "source": [ "z = encode(stroke)" ] },
{ "cell_type": "code", "execution_count": 197, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 124 }, "colab_type": "code", "id": "-37v6eZLFQu5", "outputId": "5ddac2f2-5b3b-4cd7-b81f-7a8fa374aa6b" }, "outputs": [], "source": [ "_ = decode(z, temperature=1) # convert z back to a drawing at a temperature of 1\n" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "M5ft6IEBFQu9" }, "source": [ "Create a grid of generated drawings at temperatures ranging from 0.1 to 1.0" ] },
{ "cell_type": "code", "execution_count": 102, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 130 }, "colab_type": "code", "id": "BuhaZI0aFQu9", "outputId": "d87d4b00-30c2-4302-bec8-46566ef26922", "scrolled": false }, "outputs": [], "source": [ "stroke_list = []\n", "for i in range(10):\n", "  temp = round(0.1 + i*0.1, 1) # round to avoid float noise like 0.30000000000000004\n", "  print(temp)\n", "  for j in range(5):\n", "    stroke = decode(draw_mode=False, temperature=temp)\n", "    stroke_list.append([stroke, [i, j]])\n", "    draw_strokes(stroke)\n", "\n", "# the collected samples can also be rendered as a single grid:\n", "# stroke_grid = make_grid_svg(stroke_list)\n", "# draw_strokes(stroke_grid)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4xiwp3_DFQvB" }, "source": [ "Latent Space Interpolation Example between $z_0$ and $z_1$" ] },
{ "cell_type": "code", "execution_count": 83, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 123 }, "colab_type": "code", "id": "WSX0uvZTFQvD", "outputId": "cd67af4e-5ae6-4327-876e-e1385dadbafc" }, "outputs": [], "source": [ "# use the drawing we encoded above as the first endpoint, and render it\n", "z0 = z\n", "_ = decode(z0)" ] },
{ "cell_type": "code", "execution_count": 66, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 194 }, "colab_type": "code", "id": "jQf99TxOFQvH", "outputId": "4265bd5f-8c66-494e-b26e-d3ac874d69bb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[425]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "(1, 615, 5)\n" ] } ], "source": [ "# encode a second drawing from the test set as the other endpoint\n", "stroke = test_set.random_sample()\n", "z1 = encode(stroke)\n", "_ = decode(z1)" ] },
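{ "cell_type": "markdown", "metadata": {}, "source": [ "The interpolations below use magenta's `slerp` (spherical linear interpolation). For reference, here is an illustrative sketch of the standard slerp formula; `slerp_sketch` is our own name (the cells below use the imported `slerp`), and it assumes the two vectors are non-zero and not parallel." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# illustrative only: spherical interpolation between two latent vectors\n", "def slerp_sketch(p0, p1, t):\n", "  omega = np.arccos(np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1)))\n", "  so = np.sin(omega)\n", "  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1" ] },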
{ "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "(1, 615, 5)\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "stroke = test_set.random_sample()\n", "z1 = encode(stroke)\n", "_ = decode(z1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tDqJR8_eFQvK" }, "source": [ "Now we interpolate between sheep $z_0$ and sheep $z_1$" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "colab": {}, "colab_type": "code", "id": "_YkPNL5SFQvL" }, "outputs": [], "source": [ "z_list = [] # interpolate spherically between z0 and z1\n", "N = 10\n", "for t in np.linspace(0, 1, N):\n", " z_list.append(slerp(z0, z1, t))" ] }, { "cell_type": "code", "execution_count": 85, "metadata": { "colab": {}, "colab_type": "code", "id": "UoM-W1tQFQvM" }, "outputs": [ { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# for every latent vector in z_list, sample a vector image\n", "reconstructions = []\n", "for i in range(N):\n", " reconstructions.append([decode(z_list[i], draw_mode=True), [0, i]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 122 }, "colab_type": "code", "id": "mTqmlL6GFQvQ", "outputId": "062e015f-29c6-4e77-c6db-e403d5cabd59" }, "outputs": [], "source": [ "stroke_grid = make_grid_svg(reconstructions)\n", "draw_strokes(stroke_grid)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "vFwPna6uFQvS" }, "source": [ "Let's load the Flamingo Model, and try Unconditional (Decoder-Only) Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "HH-YclgNFQvT" }, "outputs": [], "source": [ "model_dir = '/tmp/sketch_rnn/models/flamingo/lstm_uncond'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "-Znvy3KxFQvU" }, "outputs": [], "source": [ "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 221 }, "colab_type": "code", "id": "cqDNK1cYFQvZ", "outputId": "d346d57c-f51a-4286-ba55-705bc27d4d0d" }, "outputs": [], "source": [ "# construct the sketch-rnn model 
here:\n", "reset_graph()\n", "model = Model(hps_model)\n", "eval_model = Model(eval_hps_model, reuse=True)\n", "sample_model = Model(sample_hps_model, reuse=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "7wzerSI6FQvd" }, "outputs": [], "source": [ "sess = tf.InteractiveSession()\n", "sess.run(tf.global_variables_initializer())" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "6mzk8KjOFQvf", "outputId": "c450a6c6-22ee-4a58-8451-443462b42d58" }, "outputs": [], "source": [ "# loads the weights from checkpoint into our model\n", "load_checkpoint(sess, model_dir)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "X88CgcyuFQvh" }, "outputs": [], "source": [ "# unconditionally generate 10 random examples\n", "N = 10\n", "reconstructions = []\n", "for i in range(N):\n", "  reconstructions.append([decode(temperature=0.5, draw_mode=False), [0, i]])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 149 }, "colab_type": "code", "id": "k57REtd_FQvj", "outputId": "8bd69652-9d1d-475e-fc64-f205cf6b9ed1" }, "outputs": [], "source": [ "stroke_grid = make_grid_svg(reconstructions)\n", "draw_strokes(stroke_grid)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L-rJ0iUQFQvl" }, "source": [ "Let's load the owl model, and generate two sketches using two random i.i.d. Gaussian latent vectors" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "of4SWwGdFQvm" }, "outputs": [], "source": [ "model_dir = '/tmp/sketch_rnn/models/owl/lstm'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 255 }, "colab_type": "code", "id": "jJiSZFQeFQvp", "outputId": "f84360ca-c2be-482f-db57-41b5ecc05768" }, "outputs": [], "source": [ "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", "# construct the sketch-rnn model here:\n", "reset_graph()\n", "model = Model(hps_model)\n", "eval_model = Model(eval_hps_model, reuse=True)\n", "sample_model = Model(sample_hps_model, reuse=True)\n", "sess = tf.InteractiveSession()\n", "sess.run(tf.global_variables_initializer())\n", "# loads the weights from checkpoint into our model\n", "load_checkpoint(sess, model_dir)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 141 }, "colab_type": "code", "id": "vR4TDoi5FQvr", "outputId": "db08cb2c-952c-4949-d2b0-94c11351264b" }, "outputs": [], "source": [ "z_0 = np.random.randn(eval_model.hps.z_size)\n", "_ = decode(z_0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 124 }, "colab_type": "code", "id": "ZX23lTnpFQvt", "outputId": "247052f2-a0f3-4046-83d6-d08e0429fafb" }, "outputs": [], "source": [ "z_1 = np.random.randn(eval_model.hps.z_size)\n", "_ = decode(z_1)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7FjQsF_2FQvv" }, "source": [ "Let's interpolate between the two owls $z_0$ and $z_1$" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "u6G37E8_FQvw" }, "outputs": [], "source": [ "z_list = [] # interpolate 
spherically between z_0 and z_1\n", "N = 10\n", "for t in np.linspace(0, 1, N):\n", "  z_list.append(slerp(z_0, z_1, t))\n", "# for every latent vector in z_list, sample a vector image\n", "reconstructions = []\n", "for i in range(N):\n", "  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.1), [0, i]])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 149 }, "colab_type": "code", "id": "OULjMktmFQvx", "outputId": "94b7b68e-9c57-4a1b-b216-83770fa4be81" }, "outputs": [], "source": [ "stroke_grid = make_grid_svg(reconstructions)\n", "draw_strokes(stroke_grid)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OiXNC-YsFQv0" }, "source": [ "Let's load the model trained on both cats and buses! catbus!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "SL7WpDDQFQv0" }, "outputs": [], "source": [ "model_dir = '/tmp/sketch_rnn/models/catbus/lstm'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 255 }, "colab_type": "code", "id": "Cvk5WOqHFQv2", "outputId": "8081d53d-52d6-4d18-f973-a9dd44c897f2" }, "outputs": [], "source": [ "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", "# construct the sketch-rnn model here:\n", "reset_graph()\n", "model = Model(hps_model)\n", "eval_model = Model(eval_hps_model, reuse=True)\n", "sample_model = Model(sample_hps_model, reuse=True)\n", "sess = tf.InteractiveSession()\n", "sess.run(tf.global_variables_initializer())\n", "# loads the weights from checkpoint into our model\n", "load_checkpoint(sess, model_dir)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 106 }, "colab_type": "code", "id": "icvlBPVkFQv5", "outputId": "f7b415fe-4d65-4b00-c0eb-fb592597dba2" }, "outputs": [], "source": [ "z_1 = np.random.randn(eval_model.hps.z_size)\n", "z_1 = z # note: this overrides the random vector with the z we encoded earlier\n", "_ = decode(z_1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 88 }, "colab_type": "code", "id": "uaNxd0LuFQv-", "outputId": "4de5ee9a-cf14-49f4-e5f5-399a0d0b8215" }, "outputs": [], "source": [ "z_0 = np.random.randn(eval_model.hps.z_size)\n", "_ = decode(z_0)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VtSYkS6mFQwC" }, "source": [ "Let's interpolate between a cat and a bus!!!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "qIDYUxBEFQwD" }, "outputs": [], "source": [ "z_list = [] # interpolate spherically between z_0 and z_1\n", "N = 50\n", "for t in np.linspace(0, 1, N):\n", "  z_list.append(slerp(z_0, z_1, t))\n", "# for every latent vector in z_list, sample a vector image\n", "reconstructions = []\n", "for i in range(N):\n", "  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 112 }, "colab_type": "code", "id": "ZHmnSjSaFQwH", "outputId": "38fe3c7e-698b-4b19-8851-e7f3ff037744" }, "outputs": [], "source": [ "stroke_grid = make_grid_svg(reconstructions)\n", "draw_strokes(stroke_grid)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "flZ_OgzCFQwJ" }, "source": [ "Why stop here? 
Let's load the model trained on both elephants and pigs!!!" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "S8WwK8FPFQwK" }, "outputs": [], "source": [ "model_dir = '/tmp/sketch_rnn/models/elephantpig/lstm'" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 255 }, "colab_type": "code", "id": "meOH4AFXFQwM", "outputId": "764938a7-bbdc-4732-e688-a8a278ab3089" }, "outputs": [], "source": [ "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", "# construct the sketch-rnn model here:\n", "reset_graph()\n", "model = Model(hps_model)\n", "eval_model = Model(eval_hps_model, reuse=True)\n", "sample_model = Model(sample_hps_model, reuse=True)\n", "sess = tf.InteractiveSession()\n", "sess.run(tf.global_variables_initializer())\n", "# loads the weights from checkpoint into our model\n", "load_checkpoint(sess, model_dir)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 121 }, "colab_type": "code", "id": "foZiiYPdFQwO", "outputId": "a09fc4fb-110f-4280-8515-c9b673cb6b90" }, "outputs": [], "source": [ "z_0 = np.random.randn(eval_model.hps.z_size)\n", "_ = decode(z_0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 163 }, "colab_type": "code", "id": "6Gaz3QG1FQwS", "outputId": "0cfc279c-1c59-419f-86d4-ed74d5e38a26" }, "outputs": [], "source": [ "z_1 = np.random.randn(eval_model.hps.z_size)\n", "_ = decode(z_1)" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oVtr7NnGFQwU" }, "source": [ "Tribute to an episode of [South Park](https://en.wikipedia.org/wiki/An_Elephant_Makes_Love_to_a_Pig): The interpolation between an Elephant and a Pig" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "lJs9JbROFQwU" }, "outputs": [], "source": [ "z_list = [] # interpolate spherically between z_0 and z_1\n", "N = 10\n", "for t in np.linspace(0, 1, N):\n", "  z_list.append(slerp(z_0, z_1, t))\n", "# for every latent vector in z_list, sample a vector image\n", "reconstructions = []\n", "for i in range(N):\n", "  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "0FOuNfJMFQwW" }, "outputs": [], "source": [ "stroke_grid = make_grid_svg(reconstructions, grid_space_x=25.0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 130 }, "colab_type": "code", "id": "bZ6zpdiMFQwX", "outputId": "70679bd1-4dba-4c08-b39f-bbde81d22019" }, "outputs": [], "source": [ "draw_strokes(stroke_grid, factor=0.3)" ] } ], "metadata": { "colab": { "name": "Sketch_RNN.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "sketchrnn", "language": "python", "name": "sketchrnn" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }