From 35a398c42e7d815819313815350b2fb780fbcce9 Mon Sep 17 00:00:00 2001
From: Ruben van de Ven
Date: Sun, 25 Aug 2019 17:19:27 +0200
Subject: [PATCH] work & thoughts in progress

---
 README.md         |    4 +
 Sketch_RNN.ipynb  | 2227 +++++++++++++++++++++++++++++++++++++++++++++
 create_card.py    |  212 +++
 create_dataset.py |  150 +++
 generate_svg.py   |  209 +++
 5 files changed, 2802 insertions(+)
 create mode 100644 README.md
 create mode 100644 Sketch_RNN.ipynb
 create mode 100644 create_card.py
 create mode 100644 create_dataset.py
 create mode 100644 generate_svg.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..1923bcb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,4 @@
+Successful on naam-simple:
+```
+sketch_rnn_train --log_root=models/naam-simple --data_dir=datasets/naam-simple --hparams="data_set=[diede.npz],dec_model=layer_norm,dec_rnn_size=200,enc_model=layer_norm,enc_rnn_size=200,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=False,num_steps=1000"
+```
diff --git a/Sketch_RNN.ipynb b/Sketch_RNN.ipynb
new file mode 100644
index 0000000..dd9e1e9
--- /dev/null
+++ b/Sketch_RNN.ipynb
@@ -0,0 +1,2227 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "tDybPQiEFQuJ"
+   },
+   "source": [
+    "In this notebook, we will show how to load pre-trained models and draw things with sketch-rnn"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "k0GqvYgB9JLC"
+   },
+   "outputs": [],
+   "source": [
+    "# import the required libraries\n",
+    "import numpy as np\n",
+    "import time\n",
+    "import random\n",
+    "import pickle\n",
+    "import codecs\n",
+    "import collections\n",
+    "import os\n",
+    "import math\n",
+    "import json\n",
+    "import tensorflow as tf\n",
+    "from six.moves import xrange"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "UI4ZC__4FQuL"
+   },
+   "outputs": [],
+   "source": [
+    "# libraries required for visualisation:\n",
+    "from IPython.display import SVG, display\n",
+    "import PIL\n",
+    "from PIL import Image\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "# set numpy output to something sensible\n",
+    "np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "D7ObpAUh9jrk"
+   },
+   "outputs": [],
+   "source": [
+    "# !pip install -qU svgwrite"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "4xYY-TUd9aiD"
+   },
+   "outputs": [],
+   "source": [
+    "import svgwrite # conda install -c omnia svgwrite=1.1.6"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 34
+    },
+    "colab_type": "code",
+    "id": "NzPSD-XRFQuP",
+    "outputId": "daa0dd33-6d59-4d15-f437-d8ec787c8884"
+   },
+   "outputs": [],
+   "source": [
+    "tf.logging.info(\"TensorFlow Version: %s\", tf.__version__)\n"
] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LebxcF4p90OR" + }, + "outputs": [], + "source": [ + "# !pip install -q magenta" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "NkFS0E1zFQuU" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: Logging before flag parsing goes to stderr.\n", + "W0825 15:58:50.188074 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n", + "\n", + "/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9\n", + " warnings.warn(msg)\n", + "W0825 15:58:50.811670 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.\n", + "\n", + "W0825 15:58:51.114169 140470926227264 lazy_loader.py:50] \n", + "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", + "For more information, please see:\n", + " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", + " * https://github.com/tensorflow/addons\n", + " * https://github.com/tensorflow/io (for I/O related ops)\n", + "If you depend on functionality not listed there, please file an issue.\n", + "\n", + "W0825 15:58:51.115457 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.\n", + "\n", + "W0825 15:58:51.116125 140470926227264 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. 
Please use tf.compat.v1.logging.INFO instead.\n", + "\n" + ] + } + ], + "source": [ + "# import our command line tools\n", + "from magenta.models.sketch_rnn.sketch_rnn_train import *\n", + "from magenta.models.sketch_rnn.model import *\n", + "from magenta.models.sketch_rnn.utils import *\n", + "from magenta.models.sketch_rnn.rnn import *" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "GBde4xkEFQuX" + }, + "outputs": [], + "source": [ + "# little function that displays vector images and saves them to .svg\n", + "def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):\n", + " tf.gfile.MakeDirs(os.path.dirname(svg_filename))\n", + " min_x, max_x, min_y, max_y = get_bounds(data, factor)\n", + " dims = (50 + max_x - min_x, 50 + max_y - min_y)\n", + " dwg = svgwrite.Drawing(svg_filename, size=dims)\n", + " dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))\n", + " lift_pen = 1\n", + " abs_x = 25 - min_x \n", + " abs_y = 25 - min_y\n", + " p = \"M%s,%s \" % (abs_x, abs_y)\n", + " command = \"m\"\n", + " for i in xrange(len(data)):\n", + " if (lift_pen == 1):\n", + " command = \"m\"\n", + " elif (command != \"l\"):\n", + " command = \"l\"\n", + " else:\n", + " command = \"\"\n", + " x = float(data[i,0])/factor\n", + " y = float(data[i,1])/factor\n", + " lift_pen = data[i, 2]\n", + " p += command+str(x)+\",\"+str(y)+\" \"\n", + " the_color = \"black\"\n", + " stroke_width = 1\n", + " dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill(\"none\"))\n", + " dwg.save()\n", + " display(SVG(dwg.tostring()))\n", + "\n", + "# generate a 2D grid of many vector drawings\n", + "def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):\n", + " def get_start_and_end(x):\n", + " x = np.array(x)\n", + " x = x[:, 0:2]\n", + " x_start = x[0]\n", + " x_end = x.sum(axis=0)\n", + " x = x.cumsum(axis=0)\n", + " x_max = x.max(axis=0)\n", + " x_min = x.min(axis=0)\n", + " center_loc = (x_max+x_min)*0.5\n", + " return x_start-center_loc, x_end\n", + " x_pos = 0.0\n", + " y_pos = 0.0\n", + " result = [[x_pos, y_pos, 1]]\n", + " for sample in s_list:\n", + " s = sample[0]\n", + " grid_loc = sample[1]\n", + " grid_y = grid_loc[0]*grid_space+grid_space*0.5\n", + " grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5\n", + " start_loc, delta_pos = get_start_and_end(s)\n", + "\n", + " loc_x = start_loc[0]\n", + " loc_y = start_loc[1]\n", + " new_x_pos = grid_x+loc_x\n", + " new_y_pos = grid_y+loc_y\n", + " result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])\n", + "\n", + " result += s.tolist()\n", + " result[-1][2] = 1\n", + " x_pos = new_x_pos+delta_pos[0]\n", + " y_pos = new_y_pos+delta_pos[1]\n", + " return np.array(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "if7-UyxzFQuY" + }, + "source": [ + "define the path of the model you want to load, and also the path of the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Dipv1EbsFQuZ" + }, + "outputs": [], + "source": [ + "data_dir = 'datasets/naam4'\n", + "model_dir = 'models/naam4'" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "colab_type": "code", + "id": "eaSqI0fIFQub", + "outputId": "06df45a6-cc86-4f50-802e-25ae185037f7" + }, + "outputs": [], + "source": [ + "# 
download_pretrained_models(models_root_dir=models_root_dir)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 153,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "G4sRuxyn_1aO"
+   },
+   "outputs": [],
+   "source": [
+    "def load_env_compatible(data_dir, model_dir):\n",
+    "  \"\"\"Loads environment for inference mode, used in jupyter notebook.\"\"\"\n",
+    "  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py\n",
+    "  # to work with deprecated tf.HParams functionality\n",
+    "  model_params = sketch_rnn_model.get_default_hparams()\n",
+    "  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:\n",
+    "    data = json.load(f)\n",
+    "  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']\n",
+    "  for fix in fix_list:\n",
+    "    data[fix] = (data[fix] == 1)\n",
+    "  model_params.parse_json(json.dumps(data))\n",
+    "  return load_dataset(data_dir, model_params, inference_mode=True)\n",
+    "\n",
+    "def load_model_compatible(model_dir):\n",
+    "  \"\"\"Loads model for inference mode, used in jupyter notebook.\"\"\"\n",
+    "  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py\n",
+    "  # to work with deprecated tf.HParams functionality\n",
+    "  model_params = sketch_rnn_model.get_default_hparams()\n",
+    "  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:\n",
+    "    data = json.load(f)\n",
+    "  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']\n",
+    "  for fix in fix_list:\n",
+    "    data[fix] = (data[fix] == 1)\n",
+    "  model_params.parse_json(json.dumps(data))\n",
+    "\n",
+    "  model_params.batch_size = 1  # only sample one at a time\n",
+    "  eval_model_params = sketch_rnn_model.copy_hparams(model_params)\n",
+    "  eval_model_params.use_input_dropout = 0\n",
+    "  eval_model_params.use_recurrent_dropout = 0\n",
+    "  eval_model_params.use_output_dropout = 0\n",
+    "  eval_model_params.is_training = 0\n",
+    "  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)\n",
+    "  sample_model_params.max_seq_len = 1  # sample one point at a time\n",
+    "  return [model_params, eval_model_params, sample_model_params]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 154,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 153
+    },
+    "colab_type": "code",
+    "id": "9m-jSAb3FQuf",
+    "outputId": "debc045d-d15a-4b30-f747-fa4bcbd069fd"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I0825 17:10:28.492715 140470926227264 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz\n",
+      "I0825 17:10:28.666448 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from lijn.npz\n",
+      "I0825 17:10:29.123052 140470926227264 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz\n",
+      "I0825 17:10:29.189439 140470926227264 sketch_rnn_train.py:159] Dataset combined: 1083 (361/361/361), avg len 234\n",
+      "I0825 17:10:29.190502 140470926227264 sketch_rnn_train.py:166] model_params.max_seq_len 614.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "total images <= max_seq_len is 361\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "I0825 17:10:29.655039 140470926227264 sketch_rnn_train.py:209] normalizing_scale_factor 55.2581.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "total images <= 
max_seq_len is 361\n", + "total images <= max_seq_len is 361\n" + ] + } + ], + "source": [ + "[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 155, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 479 + }, + "colab_type": "code", + "id": "1pHS8TSgFQui", + "outputId": "50b0e14d-ff0f-43bf-d996-90e9e6a1491e" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "I0825 17:10:34.129225 140470926227264 model.py:87] Model using gpu.\n", + "I0825 17:10:34.139610 140470926227264 model.py:175] Input dropout mode = False.\n", + "I0825 17:10:34.141004 140470926227264 model.py:176] Output dropout mode = False.\n", + "I0825 17:10:34.143371 140470926227264 model.py:177] Recurrent dropout mode = False.\n", + "I0825 17:10:42.551079 140470926227264 model.py:87] Model using gpu.\n", + "I0825 17:10:42.553035 140470926227264 model.py:175] Input dropout mode = 0.\n", + "I0825 17:10:42.554712 140470926227264 model.py:176] Output dropout mode = 0.\n", + "I0825 17:10:42.556115 140470926227264 model.py:177] Recurrent dropout mode = 0.\n", + "I0825 17:10:43.679191 140470926227264 model.py:87] Model using gpu.\n", + "I0825 17:10:43.681183 140470926227264 model.py:175] Input dropout mode = 0.\n", + "I0825 17:10:43.682615 140470926227264 model.py:176] Output dropout mode = 0.\n", + "I0825 17:10:43.683944 140470926227264 model.py:177] Recurrent dropout mode = 0.\n" + ] + } + ], + "source": [ + "# construct the sketch-rnn model here:\n", + "reset_graph()\n", + "model = Model(hps_model)\n", + "eval_model = Model(eval_hps_model, reuse=True)\n", + "sample_model = Model(sample_hps_model, reuse=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1gxYLPTQFQuk" + }, + "outputs": [], + "source": [ + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "code", + "execution_count": 199, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "bVlDyfN_FQum", + "outputId": "fb41ce20-4c7f-4991-e9f6-559ea9b34a31" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "48.0\n" + ] + }, + { + "ename": "ValueError", + "evalue": "The passed save_path is not a valid checkpoint: models/naam4/vector-4100", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mckpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_checkpoint_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;31m# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 
7\u001b[0;31m \u001b[0msaver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrestore\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"models/naam4/vector-4100\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py\u001b[0m in \u001b[0;36mrestore\u001b[0;34m(self, sess, save_path)\u001b[0m\n\u001b[1;32m 1276\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheckpoint_management\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint_exists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1277\u001b[0m raise ValueError(\"The passed save_path is not a valid checkpoint: \" +\n\u001b[0;32m-> 1278\u001b[0;31m compat.as_text(save_path))\n\u001b[0m\u001b[1;32m 1279\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1280\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Restoring parameters from %s\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: The passed save_path is not a valid checkpoint: models/naam4/vector-4100" + ] + } + ], + "source": [ + "# loads the weights from checkpoint into our model\n", + "# load_checkpoint(sess, model_dir)\n", + "saver = tf.train.Saver(tf.global_variables())\n", + "ckpt = tf.train.get_checkpoint_state(model_dir)\n", + "print(int(ckpt.model_checkpoint_path.split(\"-\")[-1])/100)\n", + "# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)\n", + "saver.restore(sess, \"models/naam4/vector-4100\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "EOblwpFeFQuq" + }, + "source": [ + "We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke." 
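,
+    "\n",
+    "encode() pads a stroke-3 sketch with to_big_strokes() to the model's max_seq_len (614 points, five channels each) and prepends a start token -- hence the (1, 615, 5) shape printed below -- before running the encoder to get $z$; decode() then samples a sketch back from a latent vector point by point at a given temperature."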
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 159,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "tMFlV487FQur"
+   },
+   "outputs": [],
+   "source": [
+    "def encode(input_strokes):\n",
+    "  strokes = to_big_strokes(input_strokes, 614).tolist()\n",
+    "  strokes.insert(0, [0, 0, 1, 0, 0])\n",
+    "  seq_len = [len(input_strokes)]\n",
+    "  print(seq_len)\n",
+    "  draw_strokes(to_normal_strokes(np.array(strokes)))\n",
+    "  print(np.array([strokes]).shape)\n",
+    "  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 160,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "1D5CV7ZlFQut"
+   },
+   "outputs": [],
+   "source": [
+    "def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2):\n",
+    "  z = None\n",
+    "  if z_input is not None:\n",
+    "    z = [z_input]\n",
+    "  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)\n",
+    "  strokes = to_normal_strokes(sample_strokes)\n",
+    "  if draw_mode:\n",
+    "    draw_strokes(strokes, factor)\n",
+    "  return strokes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 189,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 123
+    },
+    "colab_type": "code",
+    "id": "fUOAvRQtFQuw",
+    "outputId": "c8e9a1c3-28db-4263-ac67-62ffece1e1e0"
+   },
+   "outputs": [],
+   "source": [
+    "# get a sample drawing from the test set, and render it to .svg\n",
+    "stroke = test_set.random_sample()\n",
+    "stroke = test_set.strokes[202]\n",
+    "draw_strokes(stroke)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "j114Re2JFQuz"
+   },
+   "source": [
+    "Let's try to encode the sample stroke into latent vector $z$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 190,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 123
+    },
+    "colab_type": "code",
+    "id": "DBRjPBo-FQu0",
+    "outputId": "e089dc78-88e3-44c6-ed7e-f1844471f47f"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[244]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(1, 615, 5)\n"
+     ]
+    }
+   ],
+   "source": [
+    "z = encode(stroke)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 197,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 124
+    },
+    "colab_type": "code",
+    "id": "-37v6eZLFQu5",
+    "outputId": "5ddac2f2-5b3b-4cd7-b81f-7a8fa374aa6b"
+   },
+   "outputs": [],
+   "source": [
+    "_ = decode(z, temperature=1) # convert z back to a drawing at temperature 1.0\n"
+   ]
+  },
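+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Temperature controls how much randomness the decoder injects while sampling. A quick side-by-side of the same latent $z$ at a few temperatures (a minimal sketch; it assumes `z` from the cell above):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for tau in [0.1, 0.5, 1.0]:\n",
+    "  print(tau)\n",
+    "  _ = decode(z, temperature=tau)"
+   ]
+  },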
"scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.1\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.2\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.30000000000000004\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.4\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.5\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + 
"data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.6\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.7000000000000001\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.8\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.0\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + 
"text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for i in range(10):\n", + " temp = .1 + i*.1\n", + " print(temp)\n", + " for j in range(5):\n", + " stroke = decode(draw_mode=False, temperature=temp)\n", + " draw_strokes(stroke)\n", + " \n", + " \n", + "# stroke_grid = make_grid_svg(stroke_list)\n", + "# draw_strokes(stroke_grid)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4xiwp3_DFQvB" + }, + "source": [ + "Latent Space Interpolation Example between $z_0$ and $z_1$" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 123 + }, + "colab_type": "code", + "id": "WSX0uvZTFQvD", + "outputId": "cd67af4e-5ae6-4327-876e-e1385dadbafc" + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# get a sample drawing from the test set, and render it to .svg\n", + "z0 = z\n", + "_ = decode(z0)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 194 + }, + "colab_type": "code", + "id": "jQf99TxOFQvH", + "outputId": "4265bd5f-8c66-494e-b26e-d3ac874d69bb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[425]\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 615, 5)\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "stroke = test_set.random_sample()\n", + "z1 = encode(stroke)\n", + "_ = decode(z1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tDqJR8_eFQvK" + }, + "source": [ + "Now we interpolate between sheep $z_0$ and sheep $z_1$" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_YkPNL5SFQvL" + }, + "outputs": [], + "source": [ + "z_list = [] # interpolate spherically between z0 and z1\n", + "N = 10\n", + "for t in np.linspace(0, 1, N):\n", + " z_list.append(slerp(z0, z1, t))" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "UoM-W1tQFQvM" + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/svg+xml": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# for every latent vector in z_list, sample a vector image\n", + "reconstructions = []\n", + "for i in range(N):\n", + " reconstructions.append([decode(z_list[i], draw_mode=True), [0, i]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 122 + }, + "colab_type": "code", + "id": "mTqmlL6GFQvQ", + "outputId": "062e015f-29c6-4e77-c6db-e403d5cabd59" + }, + "outputs": [], + "source": [ + "stroke_grid = make_grid_svg(reconstructions)\n", + "draw_strokes(stroke_grid)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vFwPna6uFQvS" + }, + "source": [ + "Let's load the Flamingo Model, and try Unconditional (Decoder-Only) Generation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "HH-YclgNFQvT" + }, + "outputs": [], + "source": [ + "model_dir = '/tmp/sketch_rnn/models/flamingo/lstm_uncond'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "-Znvy3KxFQvU" + }, + "outputs": [], + "source": [ + "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 221 + }, + "colab_type": "code", + "id": "cqDNK1cYFQvZ", + "outputId": "d346d57c-f51a-4286-ba55-705bc27d4d0d" + }, + "outputs": [], + "source": [ + "# construct the sketch-rnn model here:\n", + "reset_graph()\n", + "model = Model(hps_model)\n", + "eval_model = Model(eval_hps_model, reuse=True)\n", + "sample_model = Model(sample_hps_model, reuse=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7wzerSI6FQvd" + }, + "outputs": [], + "source": [ + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "6mzk8KjOFQvf", + "outputId": "c450a6c6-22ee-4a58-8451-443462b42d58" + }, + "outputs": [], + "source": [ + "# loads the weights from checkpoint into our model\n", + "load_checkpoint(sess, model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "X88CgcyuFQvh" + }, + "outputs": [], + "source": [ + "# randomly unconditionally generate 10 examples\n", + "N = 10\n", + "reconstructions = []\n", + "for i in range(N):\n", + " 
reconstructions.append([decode(temperature=0.5, draw_mode=False), [0, i]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 149 + }, + "colab_type": "code", + "id": "k57REtd_FQvj", + "outputId": "8bd69652-9d1d-475e-fc64-f205cf6b9ed1" + }, + "outputs": [], + "source": [ + "stroke_grid = make_grid_svg(reconstructions)\n", + "draw_strokes(stroke_grid)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L-rJ0iUQFQvl" + }, + "source": [ + "Let's load the owl model, and generate two sketches using two random IID gaussian latent vectors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "of4SWwGdFQvm" + }, + "outputs": [], + "source": [ + "model_dir = '/tmp/sketch_rnn/models/owl/lstm'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "colab_type": "code", + "id": "jJiSZFQeFQvp", + "outputId": "f84360ca-c2be-482f-db57-41b5ecc05768" + }, + "outputs": [], + "source": [ + "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", + "# construct the sketch-rnn model here:\n", + "reset_graph()\n", + "model = Model(hps_model)\n", + "eval_model = Model(eval_hps_model, reuse=True)\n", + "sample_model = Model(sample_hps_model, reuse=True)\n", + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())\n", + "# loads the weights from checkpoint into our model\n", + "load_checkpoint(sess, model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 141 + }, + "colab_type": "code", + "id": "vR4TDoi5FQvr", + "outputId": "db08cb2c-952c-4949-d2b0-94c11351264b" + }, + "outputs": [], + "source": [ + "z_0 = np.random.randn(eval_model.hps.z_size)\n", + "_ = decode(z_0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 124 + }, + "colab_type": "code", + "id": "ZX23lTnpFQvt", + "outputId": "247052f2-a0f3-4046-83d6-d08e0429fafb" + }, + "outputs": [], + "source": [ + "z_1 = np.random.randn(eval_model.hps.z_size)\n", + "_ = decode(z_1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "7FjQsF_2FQvv" + }, + "source": [ + "Let's interpolate between the two owls $z_0$ and $z_1$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "u6G37E8_FQvw" + }, + "outputs": [], + "source": [ + "z_list = [] # interpolate spherically between z_0 and z_1\n", + "N = 10\n", + "for t in np.linspace(0, 1, N):\n", + " z_list.append(slerp(z_0, z_1, t))\n", + "# for every latent vector in z_list, sample a vector image\n", + "reconstructions = []\n", + "for i in range(N):\n", + " reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.1), [0, i]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 149 + }, + "colab_type": "code", + "id": "OULjMktmFQvx", + "outputId": "94b7b68e-9c57-4a1b-b216-83770fa4be81" + }, + "outputs": [], + "source": [ + "stroke_grid = make_grid_svg(reconstructions)\n", + "draw_strokes(stroke_grid)" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{ + "colab_type": "text", + "id": "OiXNC-YsFQv0" + }, + "source": [ + "Let's load the model trained on both cats and buses! catbus!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "SL7WpDDQFQv0" + }, + "outputs": [], + "source": [ + "model_dir = '/tmp/sketch_rnn/models/catbus/lstm'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "colab_type": "code", + "id": "Cvk5WOqHFQv2", + "outputId": "8081d53d-52d6-4d18-f973-a9dd44c897f2" + }, + "outputs": [], + "source": [ + "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", + "# construct the sketch-rnn model here:\n", + "reset_graph()\n", + "model = Model(hps_model)\n", + "eval_model = Model(eval_hps_model, reuse=True)\n", + "sample_model = Model(sample_hps_model, reuse=True)\n", + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())\n", + "# loads the weights from checkpoint into our model\n", + "load_checkpoint(sess, model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 106 + }, + "colab_type": "code", + "id": "icvlBPVkFQv5", + "outputId": "f7b415fe-4d65-4b00-c0eb-fb592597dba2" + }, + "outputs": [], + "source": [ + "z_1 = np.random.randn(eval_model.hps.z_size)\n", + "z_1 = z\n", + "_ = decode(z_1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 88 + }, + "colab_type": "code", + "id": "uaNxd0LuFQv-", + "outputId": "4de5ee9a-cf14-49f4-e5f5-399a0d0b8215" + }, + "outputs": [], + "source": [ + "z_0 = np.random.randn(eval_model.hps.z_size)\n", + "_ = decode(z_0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VtSYkS6mFQwC" + }, + "source": [ + "Let's interpolate between a cat and a bus!!!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "qIDYUxBEFQwD" + }, + "outputs": [], + "source": [ + "z_list = [] # interpolate spherically between z_1 and z_0\n", + "N = 50\n", + "for t in np.linspace(0, 1, N):\n", + " z_list.append(slerp(z_0, z_1, t))\n", + "# for every latent vector in z_list, sample a vector image\n", + "reconstructions = []\n", + "for i in range(N):\n", + " reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 112 + }, + "colab_type": "code", + "id": "ZHmnSjSaFQwH", + "outputId": "38fe3c7e-698b-4b19-8851-e7f3ff037744" + }, + "outputs": [], + "source": [ + "stroke_grid = make_grid_svg(reconstructions)\n", + "draw_strokes(stroke_grid)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "flZ_OgzCFQwJ" + }, + "source": [ + "Why stop here? Let's load the model trained on both elephants and pigs!!!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "S8WwK8FPFQwK" + }, + "outputs": [], + "source": [ + "model_dir = '/tmp/sketch_rnn/models/elephantpig/lstm'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 255 + }, + "colab_type": "code", + "id": "meOH4AFXFQwM", + "outputId": "764938a7-bbdc-4732-e688-a8a278ab3089" + }, + "outputs": [], + "source": [ + "[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)\n", + "# construct the sketch-rnn model here:\n", + "reset_graph()\n", + "model = Model(hps_model)\n", + "eval_model = Model(eval_hps_model, reuse=True)\n", + "sample_model = Model(sample_hps_model, reuse=True)\n", + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())\n", + "# loads the weights from checkpoint into our model\n", + "load_checkpoint(sess, model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 121 + }, + "colab_type": "code", + "id": "foZiiYPdFQwO", + "outputId": "a09fc4fb-110f-4280-8515-c9b673cb6b90" + }, + "outputs": [], + "source": [ + "z_0 = np.random.randn(eval_model.hps.z_size)\n", + "_ = decode(z_0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 163 + }, + "colab_type": "code", + "id": "6Gaz3QG1FQwS", + "outputId": "0cfc279c-1c59-419f-86d4-ed74d5e38a26" + }, + "outputs": [], + "source": [ + "z_1 = np.random.randn(eval_model.hps.z_size)\n", + "_ = decode(z_1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "oVtr7NnGFQwU" + }, + "source": [ + "Tribute to an episode of [South Park](https://en.wikipedia.org/wiki/An_Elephant_Makes_Love_to_a_Pig): The interpolation between an Elephant and a Pig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "lJs9JbROFQwU" + }, + "outputs": [], + "source": [ + "z_list = [] # interpolate spherically between z_1 and z_0\n", + "N = 10\n", + "for t in np.linspace(0, 1, N):\n", + " z_list.append(slerp(z_0, z_1, t))\n", + "# for every latent vector in z_list, sample a vector image\n", + "reconstructions = []\n", + "for i in range(N):\n", + " reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "0FOuNfJMFQwW" + }, + "outputs": [], + "source": [ + "stroke_grid = make_grid_svg(reconstructions, grid_space_x=25.0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 130 + }, + "colab_type": "code", + "id": "bZ6zpdiMFQwX", + "outputId": "70679bd1-4dba-4c08-b39f-bbde81d22019" + }, + "outputs": [], + "source": [ + "draw_strokes(stroke_grid, factor=0.3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "KUgVRGnSFQwa" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "name": "Sketch_RNN.ipynb", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "sketchrnn", + "language": "python", + "name": "sketchrnn" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/create_card.py b/create_card.py new file mode 100644 index 0000000..6bb85dd --- /dev/null +++ b/create_card.py @@ -0,0 +1,212 @@ +# import the required libraries +import numpy as np +import time +import random +import pickle +import codecs +import collections +import os +import math +import json +import tensorflow as tf +from six.moves import xrange +import logging +import argparse +import svgwrite +from tqdm import tqdm +import re + + +# import our command line tools +from magenta.models.sketch_rnn.sketch_rnn_train import * +from magenta.models.sketch_rnn.model import * +from magenta.models.sketch_rnn.utils import * +from magenta.models.sketch_rnn.rnn import * + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger('card') + +argParser = argparse.ArgumentParser(description='Create postcard') +argParser.add_argument( + '--data_dir', + type=str, + default='./datasets/naam3', + help='' + ) +argParser.add_argument( + '--model_dir', + type=str, + default='./models/naam3', + ) +argParser.add_argument( + '--output_dir', + type=str, + default='generated/naam3', + ) +argParser.add_argument( + '--width', + type=str, + default='90mm', + ) +argParser.add_argument( + '--height', + type=str, + default='150mm', + ) +argParser.add_argument( + '--rows', + type=int, + default=13, + ) +argParser.add_argument( + '--columns', + type=int, + default=5, + ) +# argParser.add_argument( +# '--output_file', +# type=str, +# default='card.svg', +# ) +argParser.add_argument( + '--verbose', + '-v', + action='store_true', + help='Debug logging' + ) +args = argParser.parse_args() + +dataset_baseheight = 10 + +if args.verbose: + logger.setLevel(logging.DEBUG) + +def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25): + lift_pen = 1 + min_x, max_x, min_y, max_y = get_bounds(strokes, factor) + abs_x = start_x - min_x + abs_y = start_y - min_y + p = "M%s,%s " % (abs_x, abs_y) + # p = "M%s,%s " % (0, 0) + command = "m" + for i in xrange(len(strokes)): + if (lift_pen == 1): + command = "m" + elif (command != "l"): + command = "l" + else: + command = "" + x = float(strokes[i,0])/factor + y = float(strokes[i,1])/factor + lift_pen = strokes[i, 2] + p += f"{command}{x:.5},{y:.5} " + the_color = "black" + stroke_width = 1 + return dwg.path(p).stroke(the_color,stroke_width).fill("none") + +# little function that displays vector images and saves them to .svg +def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'): + tf.gfile.MakeDirs(os.path.dirname(svg_filename)) + min_x, max_x, min_y, max_y = get_bounds(strokes, factor) + dims = (50 + max_x - min_x, 50 + max_y - min_y) + dwg = svgwrite.Drawing(svg_filename, size=dims) + dwg.add(strokesToPath(dwg, strokes, factor)) + dwg.save() + +def load_env_compatible(data_dir, model_dir): + """Loads environment for inference mode, used in jupyter notebook.""" + # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py + # to work with depreciated tf.HParams functionality + model_params = sketch_rnn_model.get_default_hparams() + with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f: + data = json.load(f) + fix_list = ['conditional', 'is_training', 'use_input_dropout', 
+
+# little function that displays vector images and saves them to .svg
+def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
+    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
+    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
+    dims = (50 + max_x - min_x, 50 + max_y - min_y)
+    dwg = svgwrite.Drawing(svg_filename, size=dims)
+    dwg.add(strokesToPath(dwg, strokes, factor))
+    dwg.save()
+
+def load_env_compatible(data_dir, model_dir):
+    """Loads environment for inference mode, used in jupyter notebook."""
+    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
+    # to work with deprecated tf.HParams functionality
+    model_params = sketch_rnn_model.get_default_hparams()
+    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
+        data = json.load(f)
+    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
+    for fix in fix_list:
+        data[fix] = (data[fix] == 1)
+    model_params.parse_json(json.dumps(data))
+    return load_dataset(data_dir, model_params, inference_mode=True)
+
+def load_model_compatible(model_dir):
+    """Loads model for inference mode, used in jupyter notebook."""
+    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
+    # to work with deprecated tf.HParams functionality
+    model_params = sketch_rnn_model.get_default_hparams()
+    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
+        data = json.load(f)
+    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
+    for fix in fix_list:
+        data[fix] = (data[fix] == 1)
+    model_params.parse_json(json.dumps(data))
+
+    model_params.batch_size = 1  # only sample one at a time
+    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
+    eval_model_params.use_input_dropout = 0
+    eval_model_params.use_recurrent_dropout = 0
+    eval_model_params.use_output_dropout = 0
+    eval_model_params.is_training = 0
+    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
+    sample_model_params.max_seq_len = 1  # sample one point at a time
+    return [model_params, eval_model_params, sample_model_params]
+
+
+# some basic initialisation (done before encode() and decode() as they use these variables)
+[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
+# construct the sketch-rnn model here:
+reset_graph()
+model = Model(hps_model)
+eval_model = Model(eval_hps_model, reuse=True)
+sample_model = Model(sample_hps_model, reuse=True)
+
+sess = tf.InteractiveSession()
+sess.run(tf.global_variables_initializer())
+# loads the weights from checkpoint into our model
+load_checkpoint(sess, args.model_dir)
+
+def encode(input_strokes):
+    """
+    Encode input image into vector
+    """
+    strokes = to_big_strokes(input_strokes, 614).tolist()
+    strokes.insert(0, [0, 0, 1, 0, 0])
+    seq_len = [len(input_strokes)]
+    print(seq_len)
+    draw_strokes(to_normal_strokes(np.array(strokes)))
+    print(np.array([strokes]).shape)
+    return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
+
+def decode(z_input=None, temperature=0.1, factor=0.2):
+    """
+    Decode vector into strokes (image)
+    """
+    z = None
+    if z_input is not None:
+        z = [z_input]
+    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
+    strokes = to_normal_strokes(sample_strokes)
+    return strokes
+
+dims = (args.width, args.height)
+width = int(re.findall(r'\d+', args.width)[0])*10
+height = int(re.findall(r'\d+', args.height)[0])*10
+# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
+
+item_count = args.rows*args.columns
+
+factor = dataset_baseheight / (height/args.rows)
+
+with tqdm(total=item_count) as pbar:
+    for row in range(args.rows):
+        start_y = row * (float(height) / args.rows)
+        row_temp = .1 + row * (1./args.rows)
+        for column in range(args.columns):
+            strokes = decode(temperature=row_temp)
+            fn = os.path.join(args.output_dir, f'generated-{row_temp:.2f}-{column}.svg')
+            draw_strokes(strokes, svg_filename=fn)
+            # current_nr = row * args.columns + column
+            # temp = .01+current_nr * (1./item_count)
+            # start_x = column * (float(width) / args.columns)
+            # path = strokesToPath(dwg, strokes, factor, start_x, start_y)
+            # dwg.add(path)
+            pbar.update()
+# dwg.save()
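+
+# Sketch of the single-card alternative hinted at by the commented-out lines
+# above: place every decoded drawing on one shared SVG canvas instead of writing
+# one file per sample. Illustrative only; it reuses strokesToPath/decode as-is.
+# dwg = svgwrite.Drawing('card.svg', size=dims, viewBox=f"0 0 {width} {height}")
+# for row in range(args.rows):
+#     start_y = row * (float(height) / args.rows)
+#     row_temp = .1 + row * (1./args.rows)
+#     for column in range(args.columns):
+#         start_x = column * (float(width) / args.columns)
+#         strokes = decode(temperature=row_temp)
+#         dwg.add(strokesToPath(dwg, strokes, factor, start_x, start_y))
+# dwg.save()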
diff --git a/create_dataset.py b/create_dataset.py
new file mode 100644
index 0000000..a81073b
--- /dev/null
+++ b/create_dataset.py
@@ -0,0 +1,150 @@
+from rdp import rdp
+import xml.etree.ElementTree as ET
+from svg.path import parse_path
+import numpy as np
+import logging
+import argparse
+import os
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger('dataset')
+
+argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
+argParser.add_argument(
+        '--src',
+        type=str,
+        default='./data/naam',
+        help='Directory to scan for SVG files'
+        )
+argParser.add_argument(
+        '--dataset_dir',
+        type=str,
+        default='./datasets/naam',
+        )
+argParser.add_argument(
+        '--multiply',
+        type=int,
+        default=2,
+        help="If you don't have enough items, repeat them so there are at least 100 per class. (Note: currently unused; the loop below always pads to 100.)",
+        )
+argParser.add_argument(
+        '--verbose',
+        '-v',
+        action='store_true',
+        help='Debug logging'
+        )
+args = argParser.parse_args()
+
+if args.verbose:
+    logger.setLevel(logging.DEBUG)
+
+
+def getStroke3FromSvg(filename):
+    """
+    Get drawing as stroke-3 format.
+    Treats each SVG group as a separate drawing;
+    points are given as [dx, dy, pen lifted].
+    """
+    logger.debug(f"Read {filename}")
+    s = ET.parse(filename)
+    root = s.getroot()
+    groups = root.findall("{http://www.w3.org/2000/svg}g")
+
+    sketches = []
+    for group in groups:
+        svg_paths = group.findall("{http://www.w3.org/2000/svg}path")
+        paths = []
+
+        min_x = None
+        min_y = None
+        max_y = None
+
+        for p in svg_paths:
+            path = parse_path(p.get("d"))
+
+            points = [[point.end.real, point.end.imag] for point in path]
+            x_points = np.array([p[0] for p in points])
+            y_points = np.array([p[1] for p in points])
+
+            if min_x is None:
+                min_x = min(x_points)
+                min_y = min(y_points)
+                max_y = max(y_points)
+            else:
+                min_x = min(min_x, min(x_points))
+                min_y = min(min_y, min(y_points))
+                max_y = max(max_y, max(y_points))
+
+            points = np.array([[x_points[i], y_points[i]] for i in range(len(points))])
+            paths.append(points)
+
+        # scale, normalise & crop
+        scale = 512 / (max_y - min_y)
+
+        prev_x = None
+        prev_y = None
+
+        strokes = []
+        for path in paths:
+            path[:, 0] -= min_x
+            path[:, 1] -= min_y
+            path *= scale
+            # simplify using Ramer-Douglas-Peucker, see https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
+            if len(path) > 800:
+                logger.debug(f'simplify {len(path)} points with epsilon 3.5')
+                path = rdp(path, epsilon=3.5)
+                logger.debug(f'\tnow {len(path)}')
+            if len(path) > 490:
+                logger.debug(f'simplify {len(path)} points with epsilon 3')
+                path = rdp(path, epsilon=3)
+                logger.debug(f'\tnow {len(path)}')
+            if len(path) > 300:
+                logger.debug(f'simplify {len(path)} points with epsilon 2')
+                path = rdp(path, epsilon=2.0)
+                logger.debug(f'\tnow {len(path)}')
+            for point in path:
+                if prev_x is not None and prev_y is not None:
+                    strokes.append([int(point[0] - prev_x), int(point[1] - prev_y), 0])
+
+                prev_x = point[0]
+                prev_y = point[1]
+
+            # mark lifting of pen
+            strokes[-1][2] = 1
+
+        logger.debug(f"Points: {len(strokes)}")
+        # strokes = np.array(strokes, dtype=np.int16)
+        sketches.append(strokes)
+
+    return sketches
+
+def main():
+    sketches = {}
+    for dirName, subdirList, fileList in os.walk(args.src):
+        for fname in fileList:
+            filename = os.path.join(dirName, fname)
+            className = os.path.splitext(fname)[0].rstrip('0123456789.- ')
+            if className not in sketches:
+                sketches[className] = []
+            sketches[className].extend(getStroke3FromSvg(filename))
+
+    for className in sketches:
+        filename = os.path.join(args.dataset_dir, className + '.npz')
+        itemsForClass = len(sketches[className])
+        if itemsForClass < 100:
+            logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
+            extras = []
+            for i in range(100 - itemsForClass):
+                extras.append(sketches[className][i % itemsForClass])
+            sketches[className].extend(extras)
+            logger.debug(f"Now {len(sketches[className])} samples for {className}")
+        sets = sketches[className]
+        # exit()
+        np.savez_compressed(filename, train=sets, valid=sets, test=sets)
+        logger.info(f"Saved {len(sets)} samples in {filename}")
+
+
+if __name__ == '__main__':
+    main()
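+
+# Example invocation (illustrative paths): scan ./data/naam for SVGs and write
+# one .npz per class to ./datasets/naam:
+#
+#   python create_dataset.py --src=./data/naam --dataset_dir=./datasets/naam -v
+#
+# Class names are file names with trailing digits/dashes stripped, so e.g. a
+# hypothetical 'anna-01.svg' and 'anna-02.svg' would both end up in 'anna.npz'.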
diff --git a/generate_svg.py b/generate_svg.py
new file mode 100644
index 0000000..d261d36
--- /dev/null
+++ b/generate_svg.py
@@ -0,0 +1,209 @@
+# import the required libraries
+import numpy as np
+import time
+import random
+import pickle
+import codecs
+import collections
+import os
+import math
+import json
+import tensorflow as tf
+from six.moves import xrange
+import argparse
+import logging
+from tqdm import tqdm
+
+# libraries required for visualisation:
+from IPython.display import SVG, display
+import PIL
+from PIL import Image
+import matplotlib.pyplot as plt
+
+# set numpy output to something sensible
+np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
+
+import svgwrite  # conda install -c omnia svgwrite=1.1.6
+
+
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger('generate')
+
+argParser = argparse.ArgumentParser(description='Generate SVG drawings from a trained sketch-rnn model')
+argParser.add_argument(
+        '--dataset_dir',
+        type=str,
+        default='./datasets/naam',
+        )
+argParser.add_argument(
+        '--model_dir',
+        type=str,
+        default='./models/naam',
+        )
+argParser.add_argument(
+        '--generated_dir',
+        type=str,
+        default='./generated/naam',
+        )
+argParser.add_argument(
+        '--verbose',
+        '-v',
+        action='store_true',
+        help='Debug logging'
+        )
+args = argParser.parse_args()
+
+if args.verbose:
+    logger.setLevel(logging.DEBUG)
+
+
+
+data_dir = args.dataset_dir
+model_dir = args.model_dir
+
+
+tf.logging.info("TensorFlow Version: %s", tf.__version__)
+
+# import our command line tools
+from magenta.models.sketch_rnn.sketch_rnn_train import *
+from magenta.models.sketch_rnn.model import *
+from magenta.models.sketch_rnn.utils import *
+from magenta.models.sketch_rnn.rnn import *
+
+# little function that saves vector images to .svg
+def draw_strokes(data, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
+    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
+    min_x, max_x, min_y, max_y = get_bounds(data, factor)
+    dims = (50 + max_x - min_x, 50 + max_y - min_y)
+    dwg = svgwrite.Drawing(svg_filename, size=dims)
+    # dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
+    lift_pen = 1
+    abs_x = 25 - min_x
+    abs_y = 25 - min_y
+    p = "M%s,%s " % (abs_x, abs_y)
+    command = "m"
+    for i in xrange(len(data)):
+        if (lift_pen == 1):
+            command = "m"
+        elif (command != "l"):
+            command = "l"
+        else:
+            command = ""
+        x = float(data[i, 0]) / factor
+        y = float(data[i, 1]) / factor
+        lift_pen = data[i, 2]
+        p += command + str(x) + "," + str(y) + " "
+    the_color = "black"
+    stroke_width = 1
+    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
+    dwg.save()
+    # display(SVG(dwg.tostring()))
+
+# generate a 2D grid of many vector drawings
+def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
+    def get_start_and_end(x):
+        x = np.array(x)
+        x = x[:, 0:2]
+        x_start = x[0]
+        x_end = x.sum(axis=0)
+        x = x.cumsum(axis=0)
+        x_max = x.max(axis=0)
+        x_min = x.min(axis=0)
+        center_loc = (x_max + x_min) * 0.5
+        return x_start - center_loc, x_end
+    x_pos = 0.0
+    y_pos = 0.0
+    result = [[x_pos, y_pos, 1]]
+    for sample in s_list:
+        s = sample[0]
+        grid_loc = sample[1]
+        grid_y = grid_loc[0] * grid_space + grid_space * 0.5
+        grid_x = grid_loc[1] * grid_space_x + grid_space_x * 0.5
+        start_loc, delta_pos = get_start_and_end(s)
+
+        loc_x = start_loc[0]
+        loc_y = start_loc[1]
+        new_x_pos = grid_x + loc_x
+        new_y_pos = grid_y + loc_y
+        result.append([new_x_pos - x_pos, new_y_pos - y_pos, 0])
+
+        result += s.tolist()
+        result[-1][2] = 1
+        x_pos = new_x_pos + delta_pos[0]
+        y_pos = new_y_pos + delta_pos[1]
+    return np.array(result)
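+
+# make_grid_svg() expects a list of (sketch, [row, col]) pairs. A minimal,
+# commented sketch of how it could be combined with decode() further below
+# (grid positions and temperature are arbitrary):
+#
+#   grid = [[decode(temperature=0.5, draw_mode=False), [row, col]]
+#           for row in range(2) for col in range(3)]
+#   draw_strokes(make_grid_svg(grid), svg_filename='/tmp/sketch_rnn/svg/grid.svg')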
+
+def load_env_compatible(data_dir, model_dir):
+    """Loads environment for inference mode, used in jupyter notebook."""
+    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
+    # to work with deprecated tf.HParams functionality
+    model_params = sketch_rnn_model.get_default_hparams()
+    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
+        data = json.load(f)
+    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
+    for fix in fix_list:
+        data[fix] = (data[fix] == 1)
+    model_params.parse_json(json.dumps(data))
+    return load_dataset(data_dir, model_params, inference_mode=True)
+
+def load_model_compatible(model_dir):
+    """Loads model for inference mode, used in jupyter notebook."""
+    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
+    # to work with deprecated tf.HParams functionality
+    model_params = sketch_rnn_model.get_default_hparams()
+    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
+        data = json.load(f)
+    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
+    for fix in fix_list:
+        data[fix] = (data[fix] == 1)
+    model_params.parse_json(json.dumps(data))
+
+    model_params.batch_size = 1  # only sample one at a time
+    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
+    eval_model_params.use_input_dropout = 0
+    eval_model_params.use_recurrent_dropout = 0
+    eval_model_params.use_output_dropout = 0
+    eval_model_params.is_training = 0
+    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
+    sample_model_params.max_seq_len = 1  # sample one point at a time
+    return [model_params, eval_model_params, sample_model_params]
+
+[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)
+# construct the sketch-rnn model here:
+reset_graph()
+model = Model(hps_model)
+eval_model = Model(eval_hps_model, reuse=True)
+sample_model = Model(sample_hps_model, reuse=True)
+
+sess = tf.InteractiveSession()
+sess.run(tf.global_variables_initializer())
+
+# loads the weights from checkpoint into our model
+load_checkpoint(sess, args.model_dir)
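+
+# Note: encode() below is only meaningful for a model trained with
+# conditional=True; when sampling from an unconditional model, the latent
+# vector is ignored and decode() simply draws from the learned distribution.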
+
+# We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke.
+def encode(input_strokes):
+    # to_big_strokes pads to max_len=250 by default; pass hps_model.max_seq_len
+    # explicitly if your model uses longer sequences
+    strokes = to_big_strokes(input_strokes).tolist()
+    strokes.insert(0, [0, 0, 1, 0, 0])
+    seq_len = [len(input_strokes)]
+    draw_strokes(to_normal_strokes(np.array(strokes)))
+    return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
+
+def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2, filename=None):
+    z = None
+    if z_input is not None:
+        z = [z_input]
+    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
+    strokes = to_normal_strokes(sample_strokes)
+    if draw_mode:
+        draw_strokes(strokes, factor, svg_filename=filename)
+    return strokes
+
+with tqdm(total=10 * 50) as pbar:
+    for i in range(10):
+        temperature = float(i + 1) / 10.
+        for j in range(50):
+            # :.1f keeps the float representation tidy in filenames
+            filename = os.path.join(args.generated_dir, f"generated{temperature:.1f}-{j:03d}.svg")
+            _ = decode(temperature=temperature, filename=filename)
+            pbar.update()
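+
+# A possible round-trip check (commented out; assumes a conditional model and a
+# non-empty dataset): encode a random sample from the test set into a latent
+# vector, then decode it back at a low temperature:
+#
+#   z = encode(test_set.random_sample())
+#   _ = decode(z, temperature=0.2, filename=os.path.join(args.generated_dir, 'roundtrip.svg'))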