birthcard/create_card.py

213 lines
6.9 KiB
Python

# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')

# Command-line interface. Every option carries a help text (with its default)
# so that `--help` is actually informative; flags, types and defaults are
# unchanged.
argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam3',
    help='Dataset directory handed to load_dataset() (default: %(default)s)',
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam3',
    help='Directory holding model_config.json and the checkpoint to load '
         '(default: %(default)s)',
)
argParser.add_argument(
    '--output_dir',
    type=str,
    default='generated/naam3',
    help='Directory the generated SVG files are written to '
         '(default: %(default)s)',
)
argParser.add_argument(
    '--width',
    type=str,
    default='90mm',
    help='Card width including its unit, e.g. 90mm (default: %(default)s)',
)
argParser.add_argument(
    '--height',
    type=str,
    default='150mm',
    help='Card height including its unit, e.g. 150mm (default: %(default)s)',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
    help='Number of rows; each row is sampled at its own temperature '
         '(default: %(default)s)',
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
    help='Number of drawings generated per row (default: %(default)s)',
)
# NOTE: the --output_file option is disabled together with the single-card
# layout at the bottom of the script.
# argParser.add_argument(
#     '--output_file',
#     type=str,
#     default='card.svg',
# )
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

# Used at the bottom of the script to derive the scale `factor`.
# NOTE(review): presumably the height of one drawing in dataset units - confirm.
dataset_baseheight = 10

if args.verbose:
    logger.setLevel(logging.DEBUG)
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
    """Build an SVG path element from a stroke sequence.

    Each row of ``strokes`` is read as ``(dx, dy, pen_lifted)``. The two
    offsets are divided by ``factor`` and written as relative path
    coordinates; a pen flag of 1 turns the *following* point into a move
    ("m") instead of a line.

    Args:
        dwg: Object whose ``path()`` creates the path element (an
            ``svgwrite.Drawing`` where this file calls it).
        strokes: Stroke rows; columns 0, 1 and 2 of every row are used.
        factor: Divisor applied to every offset, also passed to
            ``get_bounds``.
        start_x: X anchor; the path starts at ``start_x - min_x``.
        start_y: Y anchor; the path starts at ``start_y - min_y``.

    Returns:
        The path element, stroked black with width 1 and without fill.
    """
    min_x, _max_x, min_y, _max_y = get_bounds(strokes, factor)

    # Absolute start position, offset by the minimum bounds from get_bounds.
    segments = ["M%s,%s " % (start_x - min_x, start_y - min_y)]

    pen_lifted = 1  # the very first point is always a move
    cmd = "m"
    for point in strokes:
        if pen_lifted == 1:
            cmd = "m"
        elif cmd != "l":
            cmd = "l"
        else:
            # Repeated line segments may omit the command letter.
            cmd = ""
        dx = float(point[0]) / factor
        dy = float(point[1]) / factor
        pen_lifted = point[2]
        segments.append(f"{cmd}{dx:.5},{dy:.5} ")

    return dwg.path("".join(segments)).stroke("black", 1).fill("none")
# Render a stroke sequence into a standalone .svg file.
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    """Write ``strokes`` as a single path to ``svg_filename``.

    The directory part of ``svg_filename`` is created first. The canvas is
    the bounding box reported by ``get_bounds`` plus 50 units in each
    dimension.

    Args:
        strokes: Stroke sequence accepted by ``get_bounds`` and
            ``strokesToPath``.
        factor: Scale divisor forwarded to both helpers.
        svg_filename: Destination path of the SVG file.
    """
    target_dir = os.path.dirname(svg_filename)
    tf.gfile.MakeDirs(target_dir)

    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
    canvas_w = 50 + max_x - min_x
    canvas_h = 50 + max_y - min_y

    drawing = svgwrite.Drawing(svg_filename, size=(canvas_w, canvas_h))
    drawing.add(strokesToPath(drawing, strokes, factor))
    drawing.save()
def load_env_compatible(data_dir, model_dir):
    """Load the dataset and hyper-parameters for inference.

    Reads ``model_config.json`` from ``model_dir``, turns the flag fields
    into real booleans, applies the result on top of the default
    hyper-parameters and passes those to ``load_dataset`` in inference mode.

    Adapted from
    https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    to work with the deprecated tf.HParams functionality.

    Args:
        data_dir: Dataset directory forwarded to ``load_dataset``.
        model_dir: Directory containing ``model_config.json``.

    Returns:
        The result of ``load_dataset``; this script unpacks it into the
        train/valid/test sets followed by three hyper-parameter sets.
    """
    hparams = sketch_rnn_model.get_default_hparams()

    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        config = json.load(config_file)

    # These fields must be booleans before they are parsed back in.
    flag_keys = (
        'conditional',
        'is_training',
        'use_input_dropout',
        'use_output_dropout',
        'use_recurrent_dropout',
    )
    for key in flag_keys:
        config[key] = config[key] == 1

    hparams.parse_json(json.dumps(config))
    return load_dataset(data_dir, hparams, inference_mode=True)
def load_model_compatible(model_dir):
    """Build the train, eval and sample hyper-parameter sets for inference.

    Reads ``model_config.json`` from ``model_dir``, turns the flag fields
    into real booleans and applies the result on top of the default
    hyper-parameters. From that set two copies are derived: an eval copy
    with dropout and training switched off, and a sample copy of the eval
    set limited to a sequence length of 1.

    Adapted from
    https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
    to work with the deprecated tf.HParams functionality.

    Args:
        model_dir: Directory containing ``model_config.json``.

    Returns:
        A list ``[model_params, eval_model_params, sample_model_params]``.
    """
    base_params = sketch_rnn_model.get_default_hparams()

    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as config_file:
        config = json.load(config_file)

    # These fields must be booleans before they are parsed back in.
    flag_keys = (
        'conditional',
        'is_training',
        'use_input_dropout',
        'use_output_dropout',
        'use_recurrent_dropout',
    )
    for key in flag_keys:
        config[key] = config[key] == 1
    base_params.parse_json(json.dumps(config))

    # Only one sketch is sampled at a time.
    base_params.batch_size = 1

    eval_params = sketch_rnn_model.copy_hparams(base_params)
    eval_params.use_input_dropout = 0
    eval_params.use_recurrent_dropout = 0
    eval_params.use_output_dropout = 0
    eval_params.is_training = 0

    # The sampler produces one point per step.
    sample_params = sketch_rnn_model.copy_hparams(eval_params)
    sample_params.max_seq_len = 1

    return [base_params, eval_params, sample_params]
# Module-level initialisation. This must run before encode() and decode() are
# called, because both read the globals created here (sess, eval_model,
# sample_model).
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
# Construct the sketch-rnn models: one per hyper-parameter set. The eval and
# sample models are created with reuse=True.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
# Open the session and initialise all variables before restoring.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# Load the weights from the checkpoint in --model_dir into the model.
load_checkpoint(sess, args.model_dir)
def encode(input_strokes):
    """Encode a stroke sequence into its latent vector.

    The strokes are converted with ``to_big_strokes``, prefixed with the
    start row ``[0, 0, 1, 0, 0]`` and fed through ``eval_model``.

    Side effect: the padded input is also rendered with ``draw_strokes``
    using its default filename (``/tmp/sketch_rnn/svg/sample.svg``).

    Args:
        input_strokes: Stroke sequence accepted by ``to_big_strokes``.

    Returns:
        The first (and only) entry of ``eval_model.batch_z`` for this input.
    """
    # Pad to the sequence length the model was configured with instead of a
    # hard-coded 614, so the function keeps working with other datasets.
    # NOTE(review): 614 presumably was the max_seq_len of the original
    # dataset - confirm.
    max_len = eval_model.hps.max_seq_len
    strokes = to_big_strokes(input_strokes, max_len).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    # Diagnostics go through the module logger so they honour --verbose.
    logger.debug("sequence length: %s", seq_len)
    draw_strokes(to_normal_strokes(np.array(strokes)))
    logger.debug("input shape: %s", np.array([strokes]).shape)
    feed = {
        eval_model.input_data: [strokes],
        eval_model.sequence_lengths: seq_len,
    }
    return sess.run(eval_model.batch_z, feed_dict=feed)[0]
def decode(z_input=None, temperature=0.1, factor=0.2):
    """Sample a drawing from the model and return it as normal strokes.

    Args:
        z_input: Optional latent vector. When given it is passed to
            ``sample`` wrapped in a one-element list; otherwise ``z`` is
            ``None``.
        temperature: Sampling temperature forwarded to ``sample``.
        factor: Accepted for compatibility; not used by this function.

    Returns:
        The sampled strokes after conversion with ``to_normal_strokes``.
    """
    latent = [z_input] if z_input is not None else None
    # sample() also returns a second value, which is not needed here.
    sampled, _ = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature,
        z=latent,
    )
    return to_normal_strokes(sampled)
# Generate rows x columns drawings, one SVG file each. Every row uses its own
# sampling temperature, starting at 0.1 and growing by 1/rows per row.
#
# dims, width, height, factor and start_y are only needed by the single-card
# layout that is currently commented out below; they are kept so that layout
# can be re-enabled.
dims = (args.width, args.height)
# First run of digits in the size string, times 10 (e.g. '90mm' -> 900).
# Raw strings avoid the invalid-escape warning that '\d' triggers.
width = int(re.findall(r'\d+', args.width)[0]) * 10
height = int(re.findall(r'\d+', args.height)[0]) * 10
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
item_count = args.rows * args.columns
factor = dataset_baseheight / (height / args.rows)
with tqdm(total=item_count) as pbar:
    for row in range(args.rows):
        start_y = row * (float(height) / args.rows)
        row_temp = .1 + row * (1. / args.rows)
        for column in range(args.columns):
            strokes = decode(temperature=row_temp)
            # row_temp is a float: the 'd' presentation type raises
            # ValueError for floats, so format it with two decimals.
            fn = os.path.join(args.output_dir, f'generated-{row_temp:.2f}-{column}.svg')
            draw_strokes(strokes, svg_filename=fn)
            # current_nr = row * args.columns + column
            # temp = .01+current_nr * (1./item_count)
            # start_x = column * (float(width) / args.columns)
            # path = strokesToPath(dwg, strokes, factor, start_x, start_y)
            # dwg.add(path)
            pbar.update()
# dwg.save()