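"""Generate a plotter-ready SVG postcard from a trained sketch-rnn model.

The script encodes a target sample from the test set into a latent vector,
then decodes one rendition per grid cell, stepping through progressively
later training checkpoints so the drawings mature across the page. Paths can
be distributed over several Inkscape layers (e.g. for multi-pen plotting),
and cutting-line marks can be added for a grid of cards.

Example invocation (a sketch; assumes this file is saved as card.py):
    python card.py --data_dir ./datasets/naam4 --model_dir ./models/naam4 \
        --output_file generated/naam4.svg --split_paths --create_grid
"""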
# import the required libraries
import numpy as np
import os
import math
import json
import re
import glob
import logging
import argparse

import tensorflow as tf
import svgwrite
from svgwrite.extensions import Inkscape
from tqdm import tqdm

# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')

argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam4',
    help='Directory with the sketch-rnn dataset'
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam4',
    help='Directory with the trained model checkpoints'
)
argParser.add_argument(
    '--output_file',
    type=str,
    default='generated/naam4.svg',
    help='Filename for the generated SVG'
)
argParser.add_argument(
    '--target_sample',
    type=int,
    default=202,
    help='Index of the test-set sample to use as the target'
)
argParser.add_argument(
    '--width',
    type=str,
    default='99mm',
    help='Width of a single card, in mm'
)
argParser.add_argument(
    '--height',
    type=str,
    default='190mm',
    help='Height of a single card, in mm'
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
    help='Number of rows of renditions per card'
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
    help='Number of columns of renditions per card'
)
argParser.add_argument(
    '--column_padding',
    type=int,
    default=10,
    help='Padding between renditions'
)
argParser.add_argument(
    '--page_margin',
    type=int,
    default=70,
    help='Margin around the page'
)
argParser.add_argument(
    '--max_checkpoint_factor',
    type=float,
    default=1.,
    help='Fraction of the available checkpoints to use; lower it when the later checkpoints all produce similarly smooth outcomes'
)
argParser.add_argument(
    '--split_paths',
    action='store_true',
    help='If a drawing contains multiple pen-down segments, render each as a separate path instead of one combined path'
)
argParser.add_argument(
    '--nr_of_paths',
    type=int,
    default=3,
    help='With --split_paths, the number of groups to distribute the paths over'
)
argParser.add_argument(
    '--last_is_target',
    action='store_true',
    help='Render the last item as the given target sample instead of generating it'
)
argParser.add_argument(
    '--last_in_group',
    action='store_true',
    help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
    '--create_grid',
    action='store_true',
    help='Create a grid with cutting lines'
)
argParser.add_argument(
    '--grid_width',
    type=int,
    default=3,
    help='Number of grid cells horizontally'
)
argParser.add_argument(
    '--grid_height',
    type=int,
    default=2,
    help='Number of grid cells vertically'
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)

def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
    """Convert a stroke-3 sequence into a single SVG path, scaled to fit the given box."""
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # scale uniformly to the tighter dimension (assuming both factors < 1)
    factor = min(factor_x, factor_y)

    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in range(len(strokes)):
        if lift_pen == 1:
            command = "m"  # pen was lifted: relative move
        elif command != "l":
            command = "l"  # pen is down: relative line
        else:
            command = ""  # SVG implicitly repeats the previous command
        x = float(strokes[i, 0]) * factor
        y = float(strokes[i, 1]) * factor
        lift_pen = strokes[i, 2]
        p += f"{command}{x:.5},{y:.5} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color, stroke_width).fill("none")

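# A generated path string looks like, for example:
#   "M120.0,340.0 m0.52341,-0.1210 l1.2345,0.6789 ... "
# i.e. an absolute move to the cell origin followed by relative moves
# (pen up) and relative lines (pen down).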
def splitStrokes(strokes):
    """
    Split a stroke sequence at each pen lift:
    turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
    """
    subStrokes = []
    for stroke in strokes:
        subStrokes.append(stroke)
        if stroke[2] == 1 and subStrokes:
            yield subStrokes
            subStrokes = []
    if subStrokes:
        yield subStrokes

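# For example (each element is a [dx, dy, pen_lifted] row):
#   list(splitStrokes([[1, 2, 0], [3, 4, 1], [5, 6, 0]]))
#   -> [[[1, 2, 0], [3, 4, 1]], [[5, 6, 0]]]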
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
    """
    Like strokesToPath, but instead of one single path, start a new path at
    every pen lift.
    """
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # scale uniformly to the tighter dimension (assuming both factors < 1)
    factor = min(factor_x, factor_y)

    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    paths = []
    for path in splitStrokes(strokes):
        if len(path) < 2:
            logger.warning(f"Too few strokes to create path: {path}")
            continue
        for i in range(len(path)):
            x = float(path[i][0]) * factor
            y = float(path[i][1]) * factor
            # keep track of the absolute position across sub-paths
            abs_x += x
            abs_y += y

            if i == 0:
                # first point of the sub-path: absolute move, then relative lines
                p = "M%s,%s l" % (abs_x, abs_y)
            else:
                p += f"{x:.5},{y:.5} "
        the_color = "black"
        stroke_width = 1
        paths.append(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
    return paths


# little helper that renders strokes and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # strokesToPath expects a start position and a bounding box, not a factor;
    # draw inside the 25-unit margin implied by dims above
    dwg.add(strokesToPath(dwg, strokes, 25, 25, max_x - min_x, max_y - min_y))
    dwg.save()

def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)  # stored as 0/1 in the json, needed as booleans
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)

def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)  # stored as 0/1 in the json, needed as booleans
    model_params.parse_json(json.dumps(data))

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]


# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)

def encode(input_strokes):
    """
    Encode input strokes into a latent vector z
    """
    # pad to stroke-5 format; 614 is this model's max sequence length
    strokes = to_big_strokes(input_strokes, 614).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])  # prepend the start-of-sequence token
    seq_len = [len(input_strokes)]
    logger.debug(seq_len)
    # draw_strokes(to_normal_strokes(np.array(strokes)))
    logger.debug(np.array([strokes]).shape)
    return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]

def decode(z_input=None, temperature=0.1, factor=0.2):
    """
    Decode a latent vector into a stroke sequence
    """
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes

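# Typical round trip (a sketch, assuming the session and models above are loaded):
#   z = encode(test_set.strokes[0])       # strokes -> latent vector
#   strokes = decode(z, temperature=0.5)  # latent vector -> new strokes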
def getCheckpoints(model_dir):
    # parse the checkpoint numbers from the vector-<nr>.meta filenames
    files = glob.glob(os.path.join(model_dir, 'vector-*.meta'))
    checkpoints = []
    for file in files:
        checkpoints.append(int(file.split('-')[-1][:-5]))  # strip the '.meta' suffix
    return sorted(checkpoints)

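# e.g. a model_dir containing vector-100.meta, vector-500.meta and
# vector-900.meta yields [100, 500, 900]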
currentCheckpoint = None
def loadCheckpoint(model_dir, nr):
    # loads the intermediate weights from checkpoint into our model,
    # skipping the restore if that checkpoint is already loaded
    global currentCheckpoint
    if nr == currentCheckpoint:
        return
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    currentCheckpoint = nr

# card dimensions in tenths of a millimetre, used as viewBox units
width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
grid_height = args.grid_height if args.create_grid else 1
grid_width = args.grid_width if args.create_grid else 1

# Override the given dimensions with the grid info
page_width = width / 10 * grid_width
page_height = height / 10 * grid_height
dims = (f"{page_width}mm", f"{page_height}mm")

dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width*grid_width} {height*grid_height}")
inkscapeDwg = Inkscape(dwg)
# one layer per path group, plus an extra layer for the last rendition if requested
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [inkscapeDwg.layer(label=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
    dwg.add(group)

checkpoints = getCheckpoints(args.model_dir)
items_per_page = args.rows * args.columns
item_count = items_per_page * grid_width * grid_height

logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints) - 1)
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")

# initialize z: encode the target sample from the test set
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)

# size of a single rendition cell, after subtracting margins and padding
max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns
max_height = (height - args.page_margin*2 - (args.column_padding*(args.rows-1))) / args.rows

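# Each item's checkpoint index grows with its position on the page: with the
# default 13*5 = 65 items per page, max_checkpoint_factor 1.0 and, say, 20
# checkpoints, item 0 would use checkpoint index 0 and item 64 would use
# floor(64/65 * 20) = 19, the last checkpoint.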
with tqdm(total=item_count) as pbar:
    for grid_pos_x in range(grid_width):
        grid_x = grid_pos_x * width
        for grid_pos_y in range(grid_height):
            grid_y = grid_pos_y * height

            for row in range(args.rows):
                # find the top-left point for the strokes
                min_y = grid_y + row * (max_height + args.column_padding) + args.page_margin
                for column in range(args.columns):
                    min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
                    item = row * args.columns + column
                    checkpoint_idx = math.floor(float(item) * args.max_checkpoint_factor / items_per_page * len(checkpoints))
                    logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
                    checkpoint = checkpoints[checkpoint_idx]
                    loadCheckpoint(args.model_dir, checkpoint)

                    isLast = (row == args.rows - 1 and column == args.columns - 1)

                    if isLast and args.last_is_target:
                        strokes = target_stroke
                    else:
                        strokes = decode(target_z, temperature=1)

                    if args.last_in_group and isLast:
                        # the last rendition gets its own layer
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[-1].add(path)
                    elif args.split_paths:
                        # distribute the sub-paths round-robin over the layers
                        paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
                        for i, path in enumerate(paths):
                            dwgGroups[i % args.nr_of_paths].add(path)
                    else:
                        path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                        dwgGroups[0].add(path)

                    pbar.update()

if args.create_grid:
    logger.info("Create grid")

    grid_length = 50
    grid_group = inkscapeDwg.layer(label='grid')
    with tqdm(total=(grid_width+1)*(grid_height+1)) as pbar:
        for i in range(grid_width + 1):
            for j in range(grid_height + 1):
                # vertical cutting mark through the grid intersection
                sx = i * width
                sy = j * height - grid_length
                sx2 = sx
                sy2 = sy + 2 * grid_length
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)
                # horizontal cutting mark through the same intersection
                sx = i * width - grid_length
                sy = j * height
                sx2 = sx + 2 * grid_length
                sy2 = sy
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)

                pbar.update()
    dwg.add(grid_group)

dwg.save()