# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import argparse
import logging
from tqdm import tqdm

# libraries required for visualisation:
from IPython.display import SVG, display
import PIL
from PIL import Image
import matplotlib.pyplot as plt

# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

import svgwrite  # conda install -c omnia svgwrite=1.1.6

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')

argParser = argparse.ArgumentParser(
    description='Generate SVG drawings from a trained sketch-rnn model. '
                'We do not mind overfitting, so training == validation == test.')
argParser.add_argument(
    '--dataset_dir',
    type=str,
    default='./datasets/naam',
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam',
)
argParser.add_argument(
    '--generated_dir',
    type=str,
    default='./generated/naam',
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging',
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)

data_dir = args.dataset_dir
model_dir = args.model_dir

tf.logging.info("TensorFlow Version: %s", tf.__version__)

# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *


# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # dwg.add(dwg.rect(insert=(0, 0), size=dims, fill='white'))
    lift_pen = 1
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(data)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(data[i, 0]) / factor
        y = float(data[i, 1]) / factor
        lift_pen = data[i, 2]
        p += command + str(x) + "," + str(y) + " "
    the_color = "black"
    stroke_width = 1
    dwg.add(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
    dwg.save()
    # display(SVG(dwg.tostring()))
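
# Usage sketch for draw_strokes (assumption, not part of the original script):
# `data` is a stroke-3 array of (dx, dy, pen_lift) rows, where pen_lift == 1
# ends the current stroke. The toy square below is purely illustrative.
# example_strokes = np.array([[10, 0, 0], [0, 10, 0], [-10, 0, 0], [0, -10, 1]])
# draw_strokes(example_strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/square.svg')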


# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
    def get_start_and_end(x):
        x = np.array(x)
        x = x[:, 0:2]
        x_start = x[0]
        x_end = x.sum(axis=0)
        x = x.cumsum(axis=0)
        x_max = x.max(axis=0)
        x_min = x.min(axis=0)
        center_loc = (x_max + x_min) * 0.5
        return x_start - center_loc, x_end

    x_pos = 0.0
    y_pos = 0.0
    result = [[x_pos, y_pos, 1]]
    for sample in s_list:
        s = sample[0]
        grid_loc = sample[1]
        grid_y = grid_loc[0] * grid_space + grid_space * 0.5
        grid_x = grid_loc[1] * grid_space_x + grid_space_x * 0.5
        start_loc, delta_pos = get_start_and_end(s)

        loc_x = start_loc[0]
        loc_y = start_loc[1]
        new_x_pos = grid_x + loc_x
        new_y_pos = grid_y + loc_y
        result.append([new_x_pos - x_pos, new_y_pos - y_pos, 0])

        result += s.tolist()
        result[-1][2] = 1
        x_pos = new_x_pos + delta_pos[0]
        y_pos = new_y_pos + delta_pos[1]
    return np.array(result)
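
# Usage sketch for make_grid_svg (assumption, not part of the original script):
# each s_list entry pairs a stroke-3 drawing with its [row, col] grid cell, and
# the merged drawing can be rendered with draw_strokes. Names are illustrative.
# grid = make_grid_svg([[drawing_a, [0, 0]], [drawing_b, [0, 1]]])
# draw_strokes(grid, svg_filename='/tmp/sketch_rnn/svg/grid.svg')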


def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # older model configs store boolean flags as 0/1 ints; coerce them back to bools
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)


def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # older model configs store boolean flags as 0/1 ints; coerce them back to bools
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
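
# Usage sketch (assumption, not exercised by this script): load_model_compatible
# mirrors load_env_compatible but skips the dataset, returning only the three
# HParams sets. The path below is this script's default model_dir.
# hps, eval_hps, sample_hps = load_model_compatible('./models/naam')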


[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)

# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# loads the weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)


# We define two convenience functions to encode a stroke into a latent vector,
# and decode from latent vector to stroke.
def encode(input_strokes):
    strokes = to_big_strokes(input_strokes).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    draw_strokes(to_normal_strokes(np.array(strokes)))
    return sess.run(eval_model.batch_z,
                    feed_dict={eval_model.input_data: [strokes],
                               eval_model.sequence_lengths: seq_len})[0]


def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2, filename=None):
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len,
                               temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    if draw_mode:
        # fall back to draw_strokes' default path when no filename is given
        draw_strokes(strokes, factor,
                     svg_filename=filename or '/tmp/sketch_rnn/svg/sample.svg')
    return strokes
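
# Round-trip sketch (assumption, not part of the original script): encode a
# drawing from the loaded test set into a latent vector z, then decode it back
# at low temperature. DataLoader.random_sample() comes from magenta's utils.
# z = encode(test_set.random_sample())
# _ = decode(z, temperature=0.2, filename='/tmp/sketch_rnn/svg/roundtrip.svg')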


# sample 50 drawings at each of 10 temperatures (0.1, 0.2, ..., 1.0)
with tqdm(total=10 * 50) as pbar:
    for i in range(10):
        temperature = float(i + 1) / 10.
        for j in range(50):
            filename = os.path.join(args.generated_dir, f"generated{temperature}-{j:03d}.svg")
            _ = decode(temperature=temperature, filename=filename)
            pbar.update()