birthcard/Sketch_RNN.ipynb

934 KiB

In this notebook, we will show how to load pre-trained models and draw things with sketch-rnn

In [1]:
# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
/home/ruben/.local/lib/python3.7/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
  np_resource = np.dtype([("resource", np.ubyte, 1)])
In [2]:
# libraries required for visualisation:
from IPython.display import SVG, display
import PIL
from PIL import Image
import matplotlib.pyplot as plt

# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
In [3]:
# !pip install -qU svgwrite
In [4]:
import svgwrite # conda install -c omnia svgwrite=1.1.6
In [5]:
tf.logging.info("TensorFlow Version: %s", tf.__version__)
In [6]:
# !pip install -q magenta
In [7]:
# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
WARNING: Logging before flag parsing goes to stderr.
W0826 11:57:16.532441 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/pipelines/statistics.py:132: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.

/home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/numba/errors.py:131: UserWarning: Insufficiently recent colorama version found. Numba requires colorama >= 0.3.9
  warnings.warn(msg)
W0826 11:57:17.155105 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/music/note_sequence_io.py:60: The name tf.python_io.TFRecordWriter is deprecated. Please use tf.io.TFRecordWriter instead.

W0826 11:57:17.477846 140006013421376 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

W0826 11:57:17.478699 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.set_verbosity is deprecated. Please use tf.compat.v1.logging.set_verbosity instead.

W0826 11:57:17.479223 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:34: The name tf.logging.INFO is deprecated. Please use tf.compat.v1.logging.INFO instead.

In [8]:
# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
  tf.gfile.MakeDirs(os.path.dirname(svg_filename))
  min_x, max_x, min_y, max_y = get_bounds(data, factor)
  dims = (50 + max_x - min_x, 50 + max_y - min_y)
  dwg = svgwrite.Drawing(svg_filename, size=dims)
  dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
  lift_pen = 1
  abs_x = 25 - min_x 
  abs_y = 25 - min_y
  p = "M%s,%s " % (abs_x, abs_y)
  command = "m"
  for i in xrange(len(data)):
    if (lift_pen == 1):
      command = "m"
    elif (command != "l"):
      command = "l"
    else:
      command = ""
    x = float(data[i,0])/factor
    y = float(data[i,1])/factor
    lift_pen = data[i, 2]
    p += command+str(x)+","+str(y)+" "
  the_color = "black"
  stroke_width = 1
  dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
  dwg.save()
  display(SVG(dwg.tostring()))

# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
  def get_start_and_end(x):
    x = np.array(x)
    x = x[:, 0:2]
    x_start = x[0]
    x_end = x.sum(axis=0)
    x = x.cumsum(axis=0)
    x_max = x.max(axis=0)
    x_min = x.min(axis=0)
    center_loc = (x_max+x_min)*0.5
    return x_start-center_loc, x_end
  x_pos = 0.0
  y_pos = 0.0
  result = [[x_pos, y_pos, 1]]
  for sample in s_list:
    s = sample[0]
    grid_loc = sample[1]
    grid_y = grid_loc[0]*grid_space+grid_space*0.5
    grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5
    start_loc, delta_pos = get_start_and_end(s)

    loc_x = start_loc[0]
    loc_y = start_loc[1]
    new_x_pos = grid_x+loc_x
    new_y_pos = grid_y+loc_y
    result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])

    result += s.tolist()
    result[-1][2] = 1
    x_pos = new_x_pos+delta_pos[0]
    y_pos = new_y_pos+delta_pos[1]
  return np.array(result)

define the path of the model you want to load, and also the path of the dataset

In [14]:
data_dir = 'datasets/naam5'
model_dir = 'models/naam5'
In [15]:
# download_pretrained_models(models_root_dir=models_root_dir)
In [16]:
def load_env_compatible(data_dir, model_dir):
  """Loads environment for inference mode, used in jupyter notebook."""
  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
  # to work with depreciated tf.HParams functionality
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    data = json.load(f)
  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
  for fix in fix_list:
    data[fix] = (data[fix] == 1)
  model_params.parse_json(json.dumps(data))
  return load_dataset(data_dir, model_params, inference_mode=True)

def load_model_compatible(model_dir):
  """Loads model for inference mode, used in jupyter notebook."""
  # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
  # to work with depreciated tf.HParams functionality
  model_params = sketch_rnn_model.get_default_hparams()
  with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
    data = json.load(f)
  fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
  for fix in fix_list:
    data[fix] = (data[fix] == 1)
  model_params.parse_json(json.dumps(data))

  model_params.batch_size = 1  # only sample one at a time
  eval_model_params = sketch_rnn_model.copy_hparams(model_params)
  eval_model_params.use_input_dropout = 0
  eval_model_params.use_recurrent_dropout = 0
  eval_model_params.use_output_dropout = 0
  eval_model_params.is_training = 0
  sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
  sample_model_params.max_seq_len = 1  # sample one point at a time
  return [model_params, eval_model_params, sample_model_params]
In [17]:
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)
I0826 11:57:46.998881 140006013421376 sketch_rnn_train.py:142] Loaded 161/161/161 from diede.npz
I0826 11:57:47.254209 140006013421376 sketch_rnn_train.py:142] Loaded 100/100/100 from blokletters.npz
I0826 11:57:47.277675 140006013421376 sketch_rnn_train.py:159] Dataset combined: 783 (261/261/261), avg len 313
I0826 11:57:47.278856 140006013421376 sketch_rnn_train.py:166] model_params.max_seq_len 614.
I0826 11:57:47.425982 140006013421376 sketch_rnn_train.py:209] normalizing_scale_factor 34.8942.
total images <= max_seq_len is 261
total images <= max_seq_len is 261
total images <= max_seq_len is 261
In [18]:
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
W0826 11:57:49.141304 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:62: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.

W0826 11:57:49.142381 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/sketch_rnn_train.py:65: The name tf.reset_default_graph is deprecated. Please use tf.compat.v1.reset_default_graph instead.

W0826 11:57:49.143638 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:81: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

I0826 11:57:49.144913 140006013421376 model.py:87] Model using gpu.
I0826 11:57:49.148379 140006013421376 model.py:175] Input dropout mode = False.
I0826 11:57:49.148915 140006013421376 model.py:176] Output dropout mode = False.
I0826 11:57:49.149318 140006013421376 model.py:177] Recurrent dropout mode = False.
W0826 11:57:49.149732 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:190: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0826 11:57:49.161299 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:242: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W0826 11:57:49.161976 140006013421376 deprecation.py:506] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0826 11:57:49.173728 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:253: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API
W0826 11:57:49.513816 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:256: The name tf.nn.xw_plus_b is deprecated. Please use tf.compat.v1.nn.xw_plus_b instead.

W0826 11:57:49.529942 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:266: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
W0826 11:57:49.544464 140006013421376 deprecation.py:506] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:285: calling reduce_sum_v1 (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.
Instructions for updating:
keep_dims is deprecated, use keepdims instead
W0826 11:57:49.554083 140006013421376 deprecation.py:323] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:295: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See `tf.nn.softmax_cross_entropy_with_logits_v2`.

W0826 11:57:49.582807 140006013421376 deprecation_wrapper.py:119] From /home/ruben/Documents/Geboortekaartje/sketch_rnn/venv/lib/python3.7/site-packages/magenta/models/sketch_rnn/model.py:351: The name tf.train.AdamOptimizer is deprecated. Please use tf.compat.v1.train.AdamOptimizer instead.

I0826 11:57:50.112523 140006013421376 model.py:87] Model using gpu.
I0826 11:57:50.113037 140006013421376 model.py:175] Input dropout mode = 0.
I0826 11:57:50.113482 140006013421376 model.py:176] Output dropout mode = 0.
I0826 11:57:50.113979 140006013421376 model.py:177] Recurrent dropout mode = 0.
I0826 11:57:50.264065 140006013421376 model.py:87] Model using gpu.
I0826 11:57:50.264573 140006013421376 model.py:175] Input dropout mode = 0.
I0826 11:57:50.264990 140006013421376 model.py:176] Output dropout mode = 0.
I0826 11:57:50.265478 140006013421376 model.py:177] Recurrent dropout mode = 0.
In [19]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
In [20]:
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
# saver = tf.train.Saver(tf.global_variables())
# ckpt = tf.train.get_checkpoint_state(model_dir)
# print(int(ckpt.model_checkpoint_path.split("-")[-1])/100)
# # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
# saver.restore(sess, "models/naam4/vector-4100")
W0826 11:57:52.587875 140006013421376 deprecation.py:323] From /home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
29.0
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-20-787fc7b1ddc6> in <module>()
      5 print(int(ckpt.model_checkpoint_path.split("-")[-1])/100)
      6 # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
----> 7 saver.restore(sess, "models/naam4/vector-4100")

/home/ruben/.local/lib/python3.7/site-packages/tensorflow/python/training/saver.py in restore(self, sess, save_path)
   1276     if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
   1277       raise ValueError("The passed save_path is not a valid checkpoint: " +
-> 1278                        compat.as_text(save_path))
   1279 
   1280     logging.info("Restoring parameters from %s", compat.as_text(save_path))

ValueError: The passed save_path is not a valid checkpoint: models/naam4/vector-4100

We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke.

In [159]:
def encode(input_strokes):
  strokes = to_big_strokes(input_strokes, 614).tolist()
  strokes.insert(0, [0, 0, 1, 0, 0])
  seq_len = [len(input_strokes)]
  print(seq_len)
  draw_strokes(to_normal_strokes(np.array(strokes)))
  print(np.array([strokes]).shape)
  return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
In [160]:
def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2):
  z = None
  if z_input is not None:
    z = [z_input]
  sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
  strokes = to_normal_strokes(sample_strokes)
  if draw_mode:
    draw_strokes(strokes, factor)
  return strokes
In [201]:
# get a sample drawing from the test set, and render it to .svg
stroke = test_set.random_sample()
stroke=test_set.strokes[252]
draw_strokes(stroke)

Let's try to encode the sample stroke into latent vector $z$

In [190]:
z = encode(stroke)
[244]
(1, 615, 5)
In [197]:
_ = decode(z,temperature=1) # convert z back to drawing at temperature of 0.8

Create generated grid at various temperatures from 0.1 to 1.0

In [102]:
for i in range(10):
    temp = .1 + i*.1
    print(temp)
    for j in range(5):
        stroke = decode(draw_mode=False, temperature=temp)
        draw_strokes(stroke)
        
    
# stroke_grid = make_grid_svg(stroke_list)
# draw_strokes(stroke_grid)
0.1
0.2
0.30000000000000004
0.4
0.5
0.6
0.7000000000000001
0.8
0.9
1.0

Latent Space Interpolation Example between $z_0$ and $z_1$

In [83]:
# get a sample drawing from the test set, and render it to .svg
z0 = z
_ = decode(z0)
In [66]:
stroke = test_set.random_sample()
z1 = encode(stroke)
_ = decode(z1)
[425]
(1, 615, 5)

Now we interpolate between sheep $z_0$ and sheep $z_1$

In [84]:
z_list = [] # interpolate spherically between z0 and z1
N = 10
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z0, z1, t))
In [85]:
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i], draw_mode=True), [0, i]])
In [ ]:
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

Let's load the Flamingo Model, and try Unconditional (Decoder-Only) Generation

In [ ]:
model_dir = '/tmp/sketch_rnn/models/flamingo/lstm_uncond'
In [ ]:
[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)
In [ ]:
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
In [ ]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
In [ ]:
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
In [ ]:
# randomly unconditionally generate 10 examples
N = 10
reconstructions = []
for i in range(N):
  reconstructions.append([decode(temperature=0.5, draw_mode=False), [0, i]])
In [ ]:
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

Let's load the owl model, and generate two sketches using two random IID gaussian latent vectors

In [ ]:
model_dir = '/tmp/sketch_rnn/models/owl/lstm'
In [ ]:
[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
In [ ]:
z_0 = np.random.randn(eval_model.hps.z_size)
_ = decode(z_0)
In [ ]:
z_1 = np.random.randn(eval_model.hps.z_size)
_ = decode(z_1)

Let's interpolate between the two owls $z_0$ and $z_1$

In [ ]:
z_list = [] # interpolate spherically between z_0 and z_1
N = 10
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z_0, z_1, t))
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.1), [0, i]])
In [ ]:
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

Let's load the model trained on both cats and buses! catbus!

In [ ]:
model_dir = '/tmp/sketch_rnn/models/catbus/lstm'
In [ ]:
[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
In [ ]:
z_1 = np.random.randn(eval_model.hps.z_size)
z_1 = z
_ = decode(z_1)
In [ ]:
z_0 = np.random.randn(eval_model.hps.z_size)
_ = decode(z_0)

Let's interpolate between a cat and a bus!!!

In [ ]:
z_list = [] # interpolate spherically between z_1 and z_0
N = 50
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z_0, z_1, t))
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])
In [ ]:
stroke_grid = make_grid_svg(reconstructions)
draw_strokes(stroke_grid)

Why stop here? Let's load the model trained on both elephants and pigs!!!

In [ ]:
model_dir = '/tmp/sketch_rnn/models/elephantpig/lstm'
In [ ]:
[hps_model, eval_hps_model, sample_hps_model] = load_model_compatible(model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, model_dir)
In [ ]:
z_0 = np.random.randn(eval_model.hps.z_size)
_ = decode(z_0)
In [ ]:
z_1 = np.random.randn(eval_model.hps.z_size)
_ = decode(z_1)

Tribute to an episode of South Park: The interpolation between an Elephant and a Pig

In [ ]:
z_list = [] # interpolate spherically between z_1 and z_0
N = 10
for t in np.linspace(0, 1, N):
  z_list.append(slerp(z_0, z_1, t))
# for every latent vector in z_list, sample a vector image
reconstructions = []
for i in range(N):
  reconstructions.append([decode(z_list[i], draw_mode=False, temperature=0.15), [0, i]])
In [ ]:
stroke_grid = make_grid_svg(reconstructions, grid_space_x=25.0)
In [ ]:
draw_strokes(stroke_grid, factor=0.3)
In [ ]: