# import the required libraries
import numpy as np
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
import glob
from svgwrite.extensions import Inkscape

# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')

argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam4',
    help='Directory containing the dataset'
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam4',
)
argParser.add_argument(
    '--output_file',
    type=str,
    default='generated/naam4.svg',
)
argParser.add_argument(
    '--target_sample',
    type=int,
    default=202,
)
argParser.add_argument(
    '--width',
    type=str,
    default='99mm',
)
argParser.add_argument(
    '--height',
    type=str,
    default='190mm',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
)
argParser.add_argument(
    '--column_padding',
    type=int,
    default=10,
)
argParser.add_argument(
    '--page_margin',
    type=int,
    default=70,
)
argParser.add_argument(
    '--max_checkpoint_factor',
    type=float,
    default=1.,
    help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
)
argParser.add_argument(
    '--split_paths',
    action='store_true',
    help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
)
argParser.add_argument(
    '--nr_of_paths',
    type=int,
    default=3,
    help='If split_paths is given, this number defines the number of groups the paths should be split over'
)
argParser.add_argument(
    '--last_is_target',
    action='store_true',
    help='Do not generate the last item, but use the given target sample instead'
)
argParser.add_argument(
    '--last_in_group',
    action='store_true',
    help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
    '--create_grid',
    action='store_true',
    help='Create a grid with cutting lines'
)
argParser.add_argument(
    '--grid_width',
    type=int,
    default=3,
    help='Grid items x'
)
argParser.add_argument(
    '--grid_height',
    type=int,
    default=2,
    help='Grid items y'
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)


def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
    """Convert a stroke-3 sequence into a single SVG path, scaled to fit
    within max_width x max_height and positioned at (start_x, start_y)."""
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # assuming both < 1; use the smaller factor so the drawing fits both dimensions
    factor = factor_y if factor_y < factor_x else factor_x
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(strokes)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(strokes[i, 0]) * factor
        y = float(strokes[i, 1]) * factor
        lift_pen = strokes[i, 2]
        p += f"{command}{x:.5},{y:.5} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color, stroke_width).fill("none")

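# A minimal sketch of the stroke-3 input that strokesToPath consumes (this
# example is illustrative, not part of the original pipeline): each row is
# [dx, dy, pen_lift], an offset from the previous point, where pen_lift == 1
# means the pen is lifted *after* that point. So something like:
#
#   strokes = np.array([[5, 0, 0], [0, 5, 1], [3, 3, 0]])
#   dwg = svgwrite.Drawing('/tmp/example.svg', size=(100, 100))
#   dwg.add(strokesToPath(dwg, strokes, 10, 10, 80, 80))
#
# yields a single <path> element in which "m" segments correspond to pen-up
# moves and "l" segments to drawn lines.
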
def splitStrokes(strokes):
    """
    Turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
    """
    subStrokes = []
    for stroke in strokes:
        subStrokes.append(stroke)
        if stroke[2] == 1 and subStrokes:
            yield subStrokes
            subStrokes = []
    if subStrokes:
        yield subStrokes

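# For example (taken from the docstring above), a sequence with a pen lift in
# the middle is split at that lift:
#
#   list(splitStrokes([[1, 2, 0], [3, 4, 1], [5, 6, 0]]))
#   # -> [[[1, 2, 0], [3, 4, 1]], [[5, 6, 0]]]
#
# Note that splitStrokes is a generator, hence the list() call.
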
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
    """
    Render the strokes not as a single path, but start a new path
    after every pen lift.
    """
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # assuming both < 1; use the smaller factor so the drawing fits both dimensions
    factor = factor_y if factor_y < factor_x else factor_x
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor

    paths = []
    for path in splitStrokes(strokes):
        if len(path) < 2:
            logger.warning(f"Too few strokes to create path: {path}")
            # still advance the pen position, so subsequent paths stay aligned
            for point in path:
                abs_x += float(point[0]) * factor
                abs_y += float(point[1]) * factor
            continue
        for i in xrange(len(path)):
            x = float(path[i][0]) * factor
            y = float(path[i][1]) * factor
            abs_x += x
            abs_y += y
            if i == 0:
                # the first item of each sub-path is the move to its start point
                p = "M%s,%s l" % (abs_x, abs_y)
            else:
                p += f"{x:.5},{y:.5} "
        the_color = "black"
        stroke_width = 1
        paths.append(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
    return paths


# little function that displays vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # fit the drawing into the canvas, leaving a 25-unit margin on each side
    dwg.add(strokesToPath(dwg, strokes, 25, 25, max_x - min_x, max_y - min_y))
    dwg.save()


def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)


def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout',
                'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]


# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)

# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

# loads the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)


def encode(input_strokes):
    """Encode input image into a latent vector z."""
    # pad to a fixed maximum sequence length and prepend the start-of-sequence token
    strokes = to_big_strokes(input_strokes, 614).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    logger.debug(seq_len)
    logger.debug(np.array([strokes]).shape)
    return sess.run(eval_model.batch_z,
                    feed_dict={eval_model.input_data: [strokes],
                               eval_model.sequence_lengths: seq_len})[0]


def decode(z_input=None, temperature=0.1, factor=0.2):
    """Decode a latent vector z into strokes (image)."""
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model,
                               seq_len=eval_model.hps.max_seq_len,
                               temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes


def getCheckpoints(model_dir):
    """Collect the sorted checkpoint numbers from the vector-*.meta files."""
    files = glob.glob(os.path.join(model_dir, 'vector-*.meta'))
    checkpoints = []
    for file in files:
        # keep the number between the last '-' and the '.meta' suffix
        checkpoints.append(int(file.split('-')[-1][:-5]))
    return sorted(checkpoints)


currentCheckpoint = None

def loadCheckpoint(model_dir, nr):
    global currentCheckpoint
    if nr == currentCheckpoint:
        return
    # loads the intermediate weights from this checkpoint into our model
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    currentCheckpoint = nr


width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
grid_height = args.grid_height if args.create_grid else 1
grid_width = args.grid_width if args.create_grid else 1

# Override the given dimensions with the grid info
page_width = width / 10 * grid_width
page_height = height / 10 * grid_height
dims = (f"{page_width}mm", f"{page_height}mm")

dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width*grid_width} {height*grid_height}")
inkscapeDwg = Inkscape(dwg)
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [inkscapeDwg.layer(label=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
    dwg.add(group)

checkpoints = getCheckpoints(args.model_dir)
items_per_page = args.rows * args.columns
item_count = items_per_page * grid_width * grid_height
logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints) - 1)
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")

# initialize z from the target sample
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)

max_width = (width - args.page_margin * 2 - (args.column_padding * (args.columns - 1))) / args.columns
max_height = (height - args.page_margin * 2 - (args.column_padding * (args.rows - 1))) / args.rows

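# Each item on a card is rendered from a progressively later training
# checkpoint: item i (out of items_per_page) uses checkpoint index
# floor(i * max_checkpoint_factor / items_per_page * len(checkpoints)),
# so the renditions evolve from early, rough outputs towards the smoother
# outputs of the (almost) fully trained model.
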
with tqdm(total=item_count) as pbar:
    for grid_pos_x in range(grid_width):
        grid_x = grid_pos_x * width
        for grid_pos_y in range(grid_height):
            grid_y = grid_pos_y * height
            for row in range(args.rows):
                # find the top left point for the strokes
                min_y = grid_y + row * (max_height + args.column_padding) + args.page_margin
                for column in range(args.columns):
                    min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
                    item = row * args.columns + column
                    checkpoint_idx = math.floor(float(item) * args.max_checkpoint_factor / items_per_page * len(checkpoints))
                    logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
                    checkpoint = checkpoints[checkpoint_idx]
                    loadCheckpoint(args.model_dir, checkpoint)

                    isLast = (row == args.rows - 1 and column == args.columns - 1)
                    if isLast and args.last_is_target:
                        strokes = target_stroke
                    else:
                        strokes = decode(target_z, temperature=1)

                    if args.last_in_group and isLast:
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[-1].add(path)
                    elif args.split_paths:
                        paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
                        # distribute the sub-paths round-robin over the path groups
                        for i, path in enumerate(paths):
                            dwgGroups[i % args.nr_of_paths].add(path)
                    else:
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[0].add(path)
                    pbar.update()

if args.create_grid:
    logger.info("Create grid")
    grid_length = 50
    grid_group = inkscapeDwg.layer(label='grid')
    with tqdm(total=(grid_width + 1) * (grid_height + 1)) as pbar:
        for i in range(grid_width + 1):
            for j in range(grid_height + 1):
                # vertical cutting mark
                sx = i * width
                sy = j * height - grid_length
                sx2 = sx
                sy2 = sy + 2 * grid_length
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)
                # horizontal cutting mark
                sx = i * width - grid_length
                sy = j * height
                sx2 = sx + 2 * grid_length
                sy2 = sy
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)
                pbar.update()
    dwg.add(grid_group)

dwg.save()
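
# Example invocation (assuming this script is saved as card.py; the logger
# name above suggests that, and all paths shown are the script's defaults):
#
#   python card.py --data_dir ./datasets/naam4 --model_dir ./models/naam4 \
#       --output_file generated/naam4.svg --split_paths --nr_of_paths 3 \
#       --create_grid --grid_width 3 --grid_height 2 -v
#
# This renders rows x columns renditions per card over a 3x2 grid of cards,
# adds cutting marks, and distributes the generated sub-paths over 3 separate
# Inkscape layers.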