"""Generate sheets of sketch-rnn samples as individual SVG files.

Loads a trained sketch-rnn model (magenta) and decodes ``rows x columns``
samples, with the sampling temperature increasing per row, writing one SVG
per sample into ``--output_dir``.
"""

# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re

# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')

argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam3',
    help=''
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam3',
)
argParser.add_argument(
    '--output_dir',
    type=str,
    default='generated/naam3',
)
argParser.add_argument(
    '--width',
    type=str,
    default='90mm',
)
argParser.add_argument(
    '--height',
    type=str,
    default='150mm',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
)
# argParser.add_argument(
#     '--output_file',
#     type=str,
#     default='card.svg',
# )
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

# baseline drawing height (dataset units) used to derive the scale factor
dataset_baseheight = 10

if args.verbose:
    logger.setLevel(logging.DEBUG)


def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
    """Convert a stroke-3 array into an svgwrite path element.

    Args:
        dwg: svgwrite.Drawing the path will belong to.
        strokes: Nx3 array of (dx, dy, pen_lifted) stroke-3 data.
        factor: divisor applied to coordinates (larger factor = smaller image).
        start_x, start_y: offset of the drawing inside the canvas.

    Returns:
        An svgwrite path element (black stroke, no fill).
    """
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    abs_x = start_x - min_x
    abs_y = start_y - min_y
    p = "M%s,%s " % (abs_x, abs_y)
    # p = "M%s,%s " % (0, 0)
    command = "m"
    for i in xrange(len(strokes)):
        if (lift_pen == 1):
            command = "m"
        elif (command != "l"):
            command = "l"
        else:
            # SVG repeats the previous command for consecutive coordinate pairs
            command = ""
        x = float(strokes[i, 0]) / factor
        y = float(strokes[i, 1]) / factor
        lift_pen = strokes[i, 2]
        # FIX: was "{x:.5}" (5 significant digits, may emit scientific
        # notation); ".5f" keeps fixed-point output for SVG coordinates.
        p += f"{command}{x:.5f},{y:.5f} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color, stroke_width).fill("none")


# little function that displays vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    """Render stroke-3 data to an SVG file, creating directories as needed."""
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    dwg.add(strokesToPath(dwg, strokes, factor))
    dwg.save()


def _load_hparams(model_dir):
    """Read model_config.json and return parsed sketch-rnn hparams.

    Works around deprecated tf.HParams behaviour: the boolean flags are
    stored as 0/1 in the JSON and must be converted back to real booleans
    before ``parse_json`` will accept them.
    """
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return model_params


def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with depreciated tf.HParams functionality
    model_params = _load_hparams(model_dir)
    return load_dataset(data_dir, model_params, inference_mode=True)


def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook.

    Returns [model_params, eval_model_params, sample_model_params].
    """
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with depreciated tf.HParams functionality
    model_params = _load_hparams(model_dir)
    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]


# some basic initialisation (done before encode() and decode() as they use
# these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model,
 sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)

# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# loads the weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)


def encode(input_strokes, max_len=614):
    """Encode input strokes into a latent vector z.

    Args:
        input_strokes: stroke-3 drawing to encode.
        max_len: pad length for the big-stroke representation; defaults to
            the historical hard-coded 614 (presumably the dataset's
            max_seq_len — TODO confirm against the trained model).

    Returns:
        The latent vector for the drawing (first batch element).
    """
    strokes = to_big_strokes(input_strokes, max_len).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    logger.debug('seq_len: %s', seq_len)
    draw_strokes(to_normal_strokes(np.array(strokes)))
    logger.debug('input shape: %s', np.array([strokes]).shape)
    return sess.run(
        eval_model.batch_z,
        feed_dict={eval_model.input_data: [strokes],
                   eval_model.sequence_lengths: seq_len})[0]


def decode(z_input=None, temperature=0.1, factor=0.2):
    """Decode a latent vector (or unconditional sample) into stroke-3 data."""
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(
        sess, sample_model, seq_len=eval_model.hps.max_seq_len,
        temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes


dims = (args.width, args.height)
# strip the unit suffix (e.g. "90mm" -> 90) and scale to drawing units
width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")

item_count = args.rows * args.columns
factor = dataset_baseheight / (height / args.rows)

with tqdm(total=item_count) as pbar:
    for row in range(args.rows):
        start_y = row * (float(height) / args.rows)
        # temperature rises per row, starting at .1
        row_temp = .1 + row * (1. / args.rows)
        for column in range(args.columns):
            strokes = decode(temperature=row_temp)
            # FIX: was "{row_temp:.2d}" — invalid format spec (precision is
            # not allowed with 'd', and row_temp is a float): ValueError.
            fn = os.path.join(
                args.output_dir, f'generated-{row_temp:.2f}-{column}.svg')
            draw_strokes(strokes, svg_filename=fn)
            # current_nr = row * args.columns + column
            # temp = .01+current_nr * (1./item_count)
            # start_x = column * (float(width) / args.columns)
            # path = strokesToPath(dwg, strokes, factor, start_x, start_y)
            # dwg.add(path)
            pbar.update()
# dwg.save()