# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
import glob

# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')

argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam4',
    help='Directory containing the dataset'
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam4',
)
argParser.add_argument(
    '--output_file',
    type=str,
    default='generated/naam4.svg',
)
argParser.add_argument(
    '--target_sample',
    type=int,
    default=202,
)
argParser.add_argument(
    '--width',
    type=str,
    default='100mm',
)
argParser.add_argument(
    '--height',
    type=str,
    default='190mm',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
)
argParser.add_argument(
    '--max_checkpoint_factor',
    type=float,
    default=1.,
    help='If the later checkpoints produce outcomes that are too smooth, use this factor (< 1) to limit the range of checkpoints that is used'
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)


def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
    """Convert a stroke-3 sequence into an svgwrite path, scaled to fit the given box."""
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # scale uniformly: pick the smaller factor so the drawing fits both dimensions
    factor = factor_y if factor_y < factor_x else factor_x
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(strokes)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(strokes[i, 0]) * factor
        y = float(strokes[i, 1]) * factor
        lift_pen = strokes[i, 2]
        p += f"{command}{x:.5},{y:.5} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color, stroke_width).fill("none")


# little function that displays vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # fit the drawing inside the canvas, keeping a 25px margin on each side
    dwg.add(strokesToPath(dwg, strokes, 25, 25, dims[0] - 50, dims[1] - 50))
    dwg.save()


def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)
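
# Illustrative helper, not called anywhere in this script: a minimal sketch of
# how strokesToPath() and draw_strokes() above consume sketch-rnn "stroke-3"
# data, an (N, 3) array of [dx, dy, pen_lift] rows. The triangle is made up.
def _demo_draw_triangle(svg_filename='/tmp/sketch_rnn/svg/demo.svg'):
    demo_strokes = np.array([
        [10., 0., 0.],     # pen down, move right
        [0., 10., 0.],     # move down
        [-10., -10., 1.],  # back to the start, then lift the pen
    ])
    draw_strokes(demo_strokes, svg_filename=svg_filename)
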
def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]


# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = \
    load_env_compatible(args.data_dir, args.model_dir)

# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# load the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)


def encode(input_strokes):
    """Encode an input drawing (stroke-3 format) into a latent vector z."""
    # convert to the stroke-5 format, padded to this model's maximum sequence length
    strokes = to_big_strokes(input_strokes, 614).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    logger.debug('sequence length: %s', seq_len)
    logger.debug('input shape: %s', np.array([strokes]).shape)
    return sess.run(eval_model.batch_z,
                    feed_dict={eval_model.input_data: [strokes],
                               eval_model.sequence_lengths: seq_len})[0]


def decode(z_input=None, temperature=0.1, factor=0.2):
    """Decode a latent vector z into strokes (a drawing)."""
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model,
                               seq_len=eval_model.hps.max_seq_len,
                               temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes


def getCheckpoints(model_dir):
    """Return the available checkpoint numbers in model_dir, sorted ascending."""
    files = glob.glob(os.path.join(model_dir, 'vector-*.meta'))
    checkpoints = []
    for file in files:
        # strip the '.meta' suffix and keep the checkpoint number
        checkpoints.append(int(file.split('-')[-1][:-5]))
    return sorted(checkpoints)


currentCheckpoint = None


def loadCheckpoint(model_dir, nr):
    """Load the intermediate weights of the given checkpoint into our model."""
    global currentCheckpoint
    if nr == currentCheckpoint:
        return
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    currentCheckpoint = nr
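
# Illustrative helper, not called in this script: round-trip one dataset sample
# through the latent space with encode() and decode() above; the output path is
# an arbitrary example.
def _demo_roundtrip(sample_idx=0, temperature=0.5):
    z = encode(test_set.strokes[sample_idx])
    reconstruction = decode(z, temperature=temperature)
    draw_strokes(reconstruction, svg_filename='/tmp/sketch_rnn/svg/roundtrip.svg')
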
dims = (args.width, args.height)
width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
padding = 20
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")

checkpoints = getCheckpoints(args.model_dir)
item_count = args.rows * args.columns

# initialise z from the target sample in the test set
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)

max_width = (width - (padding * (args.columns + 1))) / args.columns
max_height = (height - (padding * (args.rows + 1))) / args.rows

with tqdm(total=item_count) as pbar:
    for row in range(args.rows):
        # find the top-left point for the strokes in this row
        min_y = row * (float(height - 2 * padding) / args.rows) + padding
        for column in range(args.columns):
            min_x = column * (float(width - 2 * padding) / args.columns) + padding
            item = row * args.columns + column
            # walk through the checkpoints as the grid progresses; e.g. with
            # 13 x 5 = 65 cells, 20 checkpoints and a factor of 1.0, cell 0
            # uses checkpoint 0 and cell 64 uses floor(64/65 * 20) = 19
            checkpoint_idx = math.floor(
                float(item) * args.max_checkpoint_factor / item_count * len(checkpoints))
            checkpoint = checkpoints[checkpoint_idx]
            loadCheckpoint(args.model_dir, checkpoint)
            strokes = decode(target_z, temperature=1)
            path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
            dwg.add(path)
            pbar.update()

dwg.save()
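
# Example invocation, using the argument defaults above ('card.py' stands in
# for whatever this file is actually named):
#
#   python card.py --data_dir ./datasets/naam4 --model_dir ./models/naam4 \
#       --output_file generated/naam4.svg --rows 13 --columns 5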