# birthcard/create_card.py

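"""Generate an SVG card: a grid of sketch-rnn drawings decoded from a single
target sample, where each grid cell is rendered with a progressively later
training checkpoint of the model."""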
# import the required libraries
import numpy as np
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
import glob
# import the sketch-rnn model, utilities and training helpers from magenta
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')
argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam4',
    help='Directory containing the dataset'
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam4',
    help='Directory containing the trained model checkpoints'
)
argParser.add_argument(
    '--output_file',
    type=str,
    default='generated/naam4.svg',
    help='Filename for the generated SVG'
)
argParser.add_argument(
    '--target_sample',
    type=int,
    default=202,
    help='Index of the test set sample to use as the target drawing'
)
argParser.add_argument(
    '--width',
    type=str,
    default='100mm',
    help='Page width, including unit'
)
argParser.add_argument(
    '--height',
    type=str,
    default='190mm',
    help='Page height, including unit'
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
    help='Number of rows in the grid'
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
    help='Number of columns in the grid'
)
argParser.add_argument(
    '--column_padding',
    type=int,
    default=10,
    help='Padding between grid cells, in viewBox units'
)
argParser.add_argument(
    '--page_margin',
    type=int,
    default=70,
    help='Margin around the grid, in viewBox units'
)
argParser.add_argument(
    '--max_checkpoint_factor',
    type=float,
    default=1.,
    help='Fraction of the available checkpoints to use; lower this if the later checkpoints give outcomes that are too smooth'
)
argParser.add_argument(
    '--split_paths',
    action='store_true',
    help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
)
argParser.add_argument(
    '--nr_of_paths',
    type=int,
    default=3,
    help='If split_paths is given, this number defines the number of groups the paths are split over'
)
argParser.add_argument(
    '--last_is_target',
    action='store_true',
    help='Do not generate the last item, but render the given target sample instead'
)
argParser.add_argument(
    '--last_in_group',
    action='store_true',
    help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)
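
# Example invocation (values are illustrative; they match the defaults above):
#   python create_card.py --data_dir ./datasets/naam4 --model_dir ./models/naam4 \
#     --output_file generated/naam4.svg --rows 13 --columns 5 --split_paths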
def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
    """Convert a stroke-3 sequence into a single SVG path, scaled to fit the given box."""
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # use the smaller factor so the drawing fits inside both dimensions
    factor = factor_y if factor_y < factor_x else factor_x
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(strokes)):
        if lift_pen == 1:
            command = "m"
        elif command != "l":
            command = "l"
        else:
            command = ""
        x = float(strokes[i, 0]) * factor
        y = float(strokes[i, 1]) * factor
        lift_pen = strokes[i, 2]
        p += f"{command}{x:.5},{y:.5} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color, stroke_width).fill("none")
def splitStrokes(strokes):
    """
    Turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]:
    yield a sub-stroke every time the pen is lifted (third value == 1).
    """
    subStrokes = []
    for stroke in strokes:
        subStrokes.append(stroke)
        if stroke[2] == 1 and subStrokes:
            yield subStrokes
            subStrokes = []
    if subStrokes:
        yield subStrokes
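
# For example (hypothetical values, numpy rows shown as plain lists):
#   list(splitStrokes(np.array([[1, 2, 0], [3, 4, 1], [5, 6, 0]])))
#   -> [[[1, 2, 0], [3, 4, 1]], [[5, 6, 0]]]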
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
    """
    Like strokesToPath, but start a new path at every pen lift.
    """
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    # use the smaller factor so the drawing fits inside both dimensions
    factor = factor_y if factor_y < factor_x else factor_x
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    paths = []
    for path in splitStrokes(strokes):
        if len(path) < 2:
            logger.warning(f"Too few strokes to create path: {path}")
            # still advance the pen position so later paths stay aligned
            for point in path:
                abs_x += float(point[0]) * factor
                abs_y += float(point[1]) * factor
            continue
        for i in xrange(len(path)):
            x = float(path[i][0]) * factor
            y = float(path[i][1]) * factor
            abs_x += x
            abs_y += y
            if i == 0:
                # first point: absolute move, then relative line segments
                p = "M%s,%s l" % (abs_x, abs_y)
            else:
                p += f"{x:.5},{y:.5} "
        the_color = "black"
        stroke_width = 1
        paths.append(dwg.path(p).stroke(the_color, stroke_width).fill("none"))
    return paths
# little function that renders vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # fit the drawing inside the canvas, leaving a 25-unit margin on each side
    dwg.add(strokesToPath(dwg, strokes, 25, 25, dims[0] - 50, dims[1] - 50))
    dwg.save()
def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # booleans are stored as 0/1 in the json config; convert them back
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)
def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    # booleans are stored as 0/1 in the json config; convert them back
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)
def encode(input_strokes):
    """
    Encode an input drawing (stroke-3 format) into a latent vector z.
    """
    # pad to a fixed maximum length (assumed to match the model's max_seq_len)
    strokes = to_big_strokes(input_strokes, 614).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    logger.debug(seq_len)
    # draw_strokes(to_normal_strokes(np.array(strokes)))
    logger.debug(np.array([strokes]).shape)
    return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
def decode(z_input=None, temperature=0.1):
    """
    Decode a latent vector z into strokes (a drawing).
    """
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes
def getCheckpoints(model_dir):
    """Return the sorted step numbers of all checkpoints in model_dir."""
    files = glob.glob(os.path.join(model_dir, 'vector-*.meta'))
    checkpoints = []
    for file in files:
        # filenames look like vector-<step>.meta; extract the step number
        checkpoints.append(int(file.split('-')[-1][:-5]))
    return sorted(checkpoints)
currentCheckpoint = None

def loadCheckpoint(model_dir, nr):
    """Load the intermediate weights of the given checkpoint, skipping if already loaded."""
    global currentCheckpoint
    if nr == currentCheckpoint:
        return
    # loads the intermediate weights from checkpoint into our model
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    currentCheckpoint = nr
dims = (args.width, args.height)
# viewBox size in user units: 10 units per mm of page size
width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")

# one group per set of paths when splitting, plus an optional group for the last rendition
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
    dwg.add(group)

checkpoints = getCheckpoints(args.model_dir)
item_count = args.rows * args.columns

# initialize z: encode the target sample from the test set
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)

# the size available to a single grid cell
max_width = (width - args.page_margin * 2 - (args.column_padding * (args.columns - 1))) / args.columns
max_height = (height - args.page_margin * 2 - (args.column_padding * (args.rows - 1))) / args.rows
with tqdm(total=item_count) as pbar:
    for row in range(args.rows):
        # find the top-left point for the strokes of this row
        min_y = row * (max_height + args.column_padding) + args.page_margin
        for column in range(args.columns):
            min_x = column * (max_width + args.column_padding) + args.page_margin
            item = row * args.columns + column
            # later cells use later checkpoints, scaled by max_checkpoint_factor
            checkpoint_idx = math.floor(float(item) * args.max_checkpoint_factor / item_count * len(checkpoints))
            checkpoint = checkpoints[checkpoint_idx]
            loadCheckpoint(args.model_dir, checkpoint)
            isLast = (row == args.rows - 1 and column == args.columns - 1)
            if isLast and args.last_is_target:
                strokes = target_stroke
            else:
                strokes = decode(target_z, temperature=1)
            if args.last_in_group and isLast:
                path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                dwgGroups[-1].add(path)
            elif args.split_paths:
                paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
                # distribute consecutive paths round-robin over the groups
                for i, path in enumerate(paths):
                    dwgGroups[i % args.nr_of_paths].add(path)
            else:
                path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
                dwgGroups[0].add(path)
            pbar.update()
dwg.save()