# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import argparse
import logging
from tqdm import tqdm

# libraries required for visualisation:
from IPython.display import SVG, display
import PIL
from PIL import Image
import matplotlib.pyplot as plt

# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)

import svgwrite  # conda install -c omnia svgwrite=1.1.6

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')

argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
argParser.add_argument(
    '--dataset_dir',
    type=str,
    default='./datasets/naam',
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam',
)
argParser.add_argument(
    '--generated_dir',
    type=str,
    default='./generated/naam',
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)

data_dir = args.dataset_dir
model_dir = args.model_dir

tf.logging.info("TensorFlow Version: %s", tf.__version__)

# import our command line tools
# NOTE(review): names used below that are not defined in this file
# (get_bounds, load_dataset, sketch_rnn_model, reset_graph, Model,
# load_checkpoint, sample, to_big_strokes, to_normal_strokes) are presumably
# provided by these star imports -- confirm against the installed magenta.
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

# Keys of model_config.json that are compared against 1 and stored back as
# Python booleans before the JSON is handed to HParams.parse_json.
_BOOLEAN_HPARAM_KEYS = (
    'conditional',
    'is_training',
    'use_input_dropout',
    'use_output_dropout',
    'use_recurrent_dropout',
)


def _load_model_params(model_dir):
    """Return default sketch-rnn hparams overridden by model_config.json.

    Reads ``<model_dir>/model_config.json``, replaces every key listed in
    ``_BOOLEAN_HPARAM_KEYS`` by ``value == 1`` and parses the result into the
    default hparams object.
    """
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        config = json.load(f)
    for key in _BOOLEAN_HPARAM_KEYS:
        config[key] = (config[key] == 1)
    model_params.parse_json(json.dumps(config))
    return model_params


# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
    """Render a stroke sequence as a single black SVG path and save it.

    Args:
        data: 2-D array indexed as ``data[i, 0]``, ``data[i, 1]``,
            ``data[i, 2]``; columns 0/1 are used as relative x/y offsets and
            column 2 as the "pen lifted after this point" flag.
        factor: divisor applied to every offset; also forwarded to
            ``get_bounds``.
        svg_filename: output path. Its parent directory is created first.
    """
    out_dir = os.path.dirname(svg_filename)
    if out_dir:
        # A bare filename has an empty dirname; there is nothing to create.
        tf.gfile.MakeDirs(out_dir)
    min_x, max_x, min_y, max_y = get_bounds(data, factor)
    # 25 units of padding on every side of the drawing.
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
    lift_pen = 1
    abs_x = 25 - min_x
    abs_y = 25 - min_y
    # Collect path fragments and join once instead of repeated string +=.
    path_parts = ["M%s,%s " % (abs_x, abs_y)]
    command = "m"
    for row in data:
        if lift_pen == 1:
            # Previous point ended a stroke: move without drawing.
            command = "m"
        elif command != "l":
            command = "l"
        else:
            # Consecutive line segments reuse the previous "l" command.
            command = ""
        x = float(row[0]) / factor
        y = float(row[1]) / factor
        lift_pen = row[2]
        path_parts.append(command + str(x) + "," + str(y) + " ")
    the_color = "black"
    stroke_width = 1
    dwg.add(dwg.path("".join(path_parts)).stroke(the_color, stroke_width).fill("none"))
    dwg.save()
    # display(SVG(dwg.tostring()))


# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
    """Concatenate several drawings into one stroke array laid out on a grid.

    Args:
        s_list: iterable of ``(strokes, (row, col))`` pairs; ``strokes`` must
            support ``.tolist()`` and have at least three columns.
        grid_space: vertical size of one grid cell.
        grid_space_x: horizontal size of one grid cell.

    Returns:
        ``np.array`` of ``[dx, dy, pen]`` rows: an initial pen-up row, then for
        each drawing one repositioning row followed by its own rows, with the
        last row's pen flag forced to 1.
    """
    def get_start_and_end(x):
        # Returns (first offset minus bounding-box centre of the cumulative
        # path, total displacement of the drawing).
        x = np.array(x)
        x = x[:, 0:2]
        x_start = x[0]
        x_end = x.sum(axis=0)
        x = x.cumsum(axis=0)
        x_max = x.max(axis=0)
        x_min = x.min(axis=0)
        center_loc = (x_max + x_min) * 0.5
        return x_start - center_loc, x_end

    x_pos = 0.0
    y_pos = 0.0
    result = [[x_pos, y_pos, 1]]
    # Loop variable is deliberately not called `sample`, so it cannot be
    # confused with the `sample` function this script calls elsewhere.
    for entry in s_list:
        s = entry[0]
        grid_loc = entry[1]
        grid_y = grid_loc[0] * grid_space + grid_space * 0.5
        grid_x = grid_loc[1] * grid_space_x + grid_space_x * 0.5
        start_loc, delta_pos = get_start_and_end(s)
        new_x_pos = grid_x + start_loc[0]
        new_y_pos = grid_y + start_loc[1]
        # Jump from the current pen position to this drawing's start.
        result.append([new_x_pos - x_pos, new_y_pos - y_pos, 0])
        result += s.tolist()
        result[-1][2] = 1
        x_pos = new_x_pos + delta_pos[0]
        y_pos = new_y_pos + delta_pos[1]
    return np.array(result)


def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook.

    Returns the result of ``load_dataset(data_dir, model_params,
    inference_mode=True)``; the call site in this script unpacks it into six
    values (three datasets and three hparams objects).
    """
    model_params = _load_model_params(model_dir)
    return load_dataset(data_dir, model_params, inference_mode=True)


def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook.

    Returns:
        ``[model_params, eval_model_params, sample_model_params]`` where the
        eval copy has dropout and training switched off and the sample copy
        additionally has ``max_seq_len = 1``.
    """
    model_params = _load_model_params(model_dir)

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]


[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)

# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# loads the weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)

# We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke.
def encode(input_strokes):
    """Encode a stroke sequence into a latent vector with the eval model.

    The input is converted with ``to_big_strokes``, a leading
    ``[0, 0, 1, 0, 0]`` row is inserted (presumably the start-of-sequence
    token -- confirm against magenta), the result is drawn to
    ``draw_strokes``' default SVG path, and ``eval_model.batch_z`` is
    evaluated for it.

    Returns:
        The first (only) row of the evaluated ``batch_z``.
    """
    strokes = to_big_strokes(input_strokes).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    draw_strokes(to_normal_strokes(np.array(strokes)))
    feed = {
        eval_model.input_data: [strokes],
        eval_model.sequence_lengths: seq_len,
    }
    return sess.run(eval_model.batch_z, feed_dict=feed)[0]


def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2, filename=None):
    """Sample a drawing from the model, optionally conditioned on a latent.

    Args:
        z_input: latent vector to condition on, or ``None`` to pass ``z=None``
            to ``sample``.
        draw_mode: when true, the sampled strokes are also saved as SVG.
        temperature: sampling temperature forwarded to ``sample``.
        factor: scale factor forwarded to ``draw_strokes``.
        filename: SVG output path; ``None`` means "use ``draw_strokes``'
            default path".

    Returns:
        The sampled strokes after ``to_normal_strokes``.
    """
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature,
        z=z,
    )
    strokes = to_normal_strokes(sample_strokes)
    if draw_mode:
        if filename is None:
            # Forwarding None would override draw_strokes' default path and
            # fail in os.path.dirname(None); let the default apply instead.
            draw_strokes(strokes, factor)
        else:
            draw_strokes(strokes, factor, svg_filename=filename)
    return strokes


# Generate 50 samples for each temperature 0.1, 0.2, ..., 1.0 and save each
# one as an SVG in --generated_dir.
with tqdm(total=10*50) as pbar:
    for i in range(10):
        temperature = float(i+1) / 10.
        for j in range(50):
            filename = os.path.join(args.generated_dir, f"generated{temperature}-{j:03d}.svg")
            _ = decode(temperature=temperature, filename=filename)
            pbar.update()