# birthcard/create_card.py

# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
import glob
from svgwrite.extensions import Inkscape
# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')
argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam4',
    help='Directory containing the dataset'
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam4',
)
argParser.add_argument(
    '--output_file',
    type=str,
    default='generated/naam4.svg',
)
argParser.add_argument(
    '--target_sample',
    type=int,
    default=202,
)
argParser.add_argument(
    '--width',
    type=str,
    default='99mm',
)
argParser.add_argument(
    '--height',
    type=str,
    default='190mm',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
)
argParser.add_argument(
    '--column_padding',
    type=int,
    default=10,
)
argParser.add_argument(
    '--page_margin',
    type=int,
    default=70,
)
argParser.add_argument(
    '--max_checkpoint_factor',
    type=float,
    default=1.,
    help='If later checkpoints produce outcomes that are too smooth, limit the checkpoints used with this factor'
)
argParser.add_argument(
    '--split_paths',
    action='store_true',
    help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
)
argParser.add_argument(
    '--nr_of_paths',
    type=int,
    default=3,
    help='If split_paths is given, this number defines the number of groups the paths are split over'
)
argParser.add_argument(
    '--last_is_target',
    action='store_true',
    help='Do not generate the last item, but render the given target sample instead'
)
argParser.add_argument(
    '--last_in_group',
    action='store_true',
    help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
    '--create_grid',
    action='store_true',
    help='Create a grid with cutting lines'
)
argParser.add_argument(
    '--grid_width',
    type=int,
    default=3,
    help='Number of grid items along the x-axis'
)
argParser.add_argument(
    '--grid_height',
    type=int,
    default=2,
    help='Number of grid items along the y-axis'
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()
if args.verbose:
    logger.setLevel(logging.DEBUG)
def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
    lift_pen = 1
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x-min_x)
    factor_y = max_height / (max_y-min_y)
    # use the smaller factor so the drawing fits within both bounds
    factor = min(factor_x, factor_y)
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    p = "M%s,%s " % (abs_x, abs_y)
    command = "m"
    for i in xrange(len(strokes)):
        if (lift_pen == 1):
            command = "m"
        elif (command != "l"):
            command = "l"
        else:
            command = ""
        x = float(strokes[i,0])*factor
        y = float(strokes[i,1])*factor
        lift_pen = strokes[i, 2]
        p += f"{command}{x:.5},{y:.5} "
    the_color = "black"
    stroke_width = 1
    return dwg.path(p).stroke(the_color,stroke_width).fill("none")
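
# A quick trace of strokesToPath with illustrative values: for strokes
# [[10, 0, 0], [0, 10, 1]] and factor 1, the path data becomes roughly
# "M<abs_x>,<abs_y> m10.0,0.0 l0.0,10.0", i.e. an absolute move to the
# start point followed by relative move ("m") and line ("l") commands.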
def splitStrokes(strokes):
    """
    turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
    """
    subStrokes = []
    for stroke in strokes:
        subStrokes.append(stroke)
        if stroke[2] == 1 and subStrokes:
            yield subStrokes
            subStrokes = []
    if subStrokes:
        yield subStrokes
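
# Example with illustrative values:
#   list(splitStrokes([[1, 2, 0], [3, 4, 1], [5, 6, 0]]))
#   -> [[[1, 2, 0], [3, 4, 1]], [[5, 6, 0]]]
# i.e. the stroke list is cut after every pen lift, as in the docstring.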
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
    """
    Render the strokes not as a single path, but start a new path
    after every pen lift
    """
    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
    factor_x = max_width / (max_x-min_x)
    factor_y = max_height / (max_y-min_y)
    # use the smaller factor so the drawing fits within both bounds
    factor = min(factor_x, factor_y)
    abs_x = start_x - min_x * factor
    abs_y = start_y - min_y * factor
    paths = []
    for path in splitStrokes(strokes):
        if len(path) < 2:
            logger.warning(f"Too few strokes to create path: {path}")
            continue
        for i in xrange(len(path)):
            x = float(path[i][0])*factor
            y = float(path[i][1])*factor
            abs_x += x
            abs_y += y
            if i == 0:
                # the first item is the move to the path's start point
                p = "M%s,%s l" % (abs_x, abs_y)
            else:
                p += f"{x:.5},{y:.5} "
        the_color = "black"
        stroke_width = 1
        paths.append(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
    return paths
# little helper that renders a vector image and saves it to .svg
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    tf.gfile.MakeDirs(os.path.dirname(svg_filename))
    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    dims = (50 + max_x - min_x, 50 + max_y - min_y)
    dwg = svgwrite.Drawing(svg_filename, size=dims)
    # strokesToPath needs a start point and a maximum size; a 25px margin
    # on each side matches the canvas dimensions above
    dwg.add(strokesToPath(dwg, strokes, 25, 25, max_x - min_x, max_y - min_y))
    dwg.save()
def load_env_compatible(data_dir, model_dir):
    """Loads environment for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    return load_dataset(data_dir, model_params, inference_mode=True)
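
# load_dataset returns [train_set, valid_set, test_set, hps_model,
# eval_hps_model, sample_hps_model]; it is unpacked in that order during
# the initialisation below.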
def load_model_compatible(model_dir):
    """Loads model for inference mode, used in jupyter notebook."""
    # modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    # to work with deprecated tf.HParams functionality
    model_params = sketch_rnn_model.get_default_hparams()
    with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        data = json.load(f)
    fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
    for fix in fix_list:
        data[fix] = (data[fix] == 1)
    model_params.parse_json(json.dumps(data))
    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)
def encode(input_strokes):
    """
    Encode the input strokes into a latent vector
    """
    # pad to a fixed sequence length (614, presumably the model's max_seq_len)
    strokes = to_big_strokes(input_strokes, 614).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    logger.debug(f"Sequence length: {seq_len}, input shape: {np.array([strokes]).shape}")
    # draw_strokes(to_normal_strokes(np.array(strokes)))
    return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
def decode(z_input=None, temperature=0.1, factor=0.2):
    """
    Decode a latent vector into strokes (an image)
    """
    z = None
    if z_input is not None:
        z = [z_input]
    sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
    strokes = to_normal_strokes(sample_strokes)
    return strokes
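
# Usage sketch, mirroring the calls made further down:
#   z = encode(test_set.strokes[some_index])  # some_index is illustrative
#   strokes = decode(z, temperature=0.5)      # higher temperature, more variation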
def getCheckpoints(model_dir):
    files = glob.glob(os.path.join(model_dir, 'vector-*.meta'))
    checkpoints = []
    for file in files:
        # strip the 'vector-' prefix and '.meta' suffix to get the step number
        checkpoints.append(int(file.split('-')[-1][:-5]))
    return sorted(checkpoints)
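
# e.g. a model_dir containing vector-300.meta and vector-4100.meta yields
# [300, 4100]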
currentCheckpoint = None
def loadCheckpoint(model_dir, nr):
    global currentCheckpoint
    if nr == currentCheckpoint:
        return
    # loads the intermediate weights from checkpoint into our model
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    # remember what is loaded so repeated calls can return early
    currentCheckpoint = nr
# page dimensions in tenths of a millimetre, used as viewBox units
width = int(re.findall(r'\d+', args.width)[0])*10
height = int(re.findall(r'\d+', args.height)[0])*10
grid_height = args.grid_height if args.create_grid else 1
grid_width = args.grid_width if args.create_grid else 1
# Override the given dimensions with the grid info
page_width = width/10*grid_width
page_height = height/10*grid_height
dims = (f"{page_width}mm", f"{page_height}mm")
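
# Worked example with the defaults: '99mm' gives width = 99*10 = 990 viewBox
# units; without --create_grid the grid is 1x1, so the page stays 99mm x 190mm.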
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width*grid_width} {height*grid_height}")
inkscapeDwg = Inkscape(dwg)
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [inkscapeDwg.layer(label=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
    dwg.add(group)
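
# The groups are created as Inkscape layers, presumably so each set of paths
# can be shown, hidden or plotted separately in Inkscape.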
checkpoints = getCheckpoints(args.model_dir)
items_per_page = args.rows*args.columns
item_count = items_per_page*grid_width*grid_height
logger.info(f"Checkpoints: {checkpoints}, factor: {args.max_checkpoint_factor}")
max_checkpoint_idx = math.floor(args.max_checkpoint_factor * len(checkpoints) - 1)
logger.info(f"Max chkpt: {max_checkpoint_idx}: {checkpoints[max_checkpoint_idx]}")
# factor = dataset_baseheight / (height/args.rows)
# initialize z from the target sample
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)
max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns
max_height = (height - args.page_margin*2 - (args.column_padding*(args.rows-1))) / args.rows
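
# With the defaults this gives max_width = (990 - 140 - 10*4) / 5 = 162 units
# (16.2mm) per cell and max_height = (1900 - 140 - 10*12) / 13, about 126.2 units.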
with tqdm(total=item_count) as pbar:
    for grid_pos_x in range(grid_width):
        grid_x = grid_pos_x * width
        for grid_pos_y in range(grid_height):
            grid_y = grid_pos_y * height
            for row in range(args.rows):
                # find the top left point for the strokes
                min_y = grid_y + row * (max_height + args.column_padding) + args.page_margin
                for column in range(args.columns):
                    min_x = grid_x + column * (max_width + args.column_padding) + args.page_margin
                    item = row*args.columns + column
                    # map the item index to a checkpoint: earlier cells on the
                    # page use checkpoints from earlier in the training
                    checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/items_per_page * len(checkpoints))
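                    # e.g. with 13*5 = 65 items per page, factor 1.0 and 10
                    # checkpoints, item 32 maps to floor(32/65 * 10) = 4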
logger.info(f"For item {item}/{items_per_page} use checkpoint idx {checkpoint_idx}")
checkpoint = checkpoints[checkpoint_idx]
loadCheckpoint(args.model_dir, checkpoint)
isLast = (row == args.rows-1 and column == args.columns-1)
if isLast and args.last_is_target:
strokes = target_stroke
else:
strokes = decode(target_z, temperature=1)
# strokes = target_stroke
if args.last_in_group and isLast:
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
dwgGroups[-1].add(path)
elif args.split_paths:
paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
i = 0
for path in paths:
group = dwgGroups[i % args.nr_of_paths]
i+=1
group.add(path)
else:
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
dwgGroups[0].add(path)
pbar.update()
if args.create_grid:
    logger.info("Create grid")
    grid_length = 50
    grid_group = inkscapeDwg.layer(label='grid')
    with tqdm(total=(grid_width+1)*(grid_height+1)) as pbar:
        for i in range(grid_width + 1):
            for j in range(grid_height + 1):
                # vertical cutting mark
                sx = i * width
                sy = j * height - grid_length
                sx2 = sx
                sy2 = sy + 2*grid_length
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)
                # horizontal cutting mark
                sx = i * width - grid_length
                sy = j * height
                sx2 = sx + 2*grid_length
                sy2 = sy
                p = f"M{sx},{sy} L{sx2},{sy2}"
                path = dwg.path(p).stroke('black', 1).fill("none")
                grid_group.add(path)
                pbar.update()
    dwg.add(grid_group)
dwg.save()