213 lines
6.9 KiB
Python
213 lines
6.9 KiB
Python
# import the required libraries
|
|
import numpy as np
|
|
import time
|
|
import random
|
|
import pickle
|
|
import codecs
|
|
import collections
|
|
import os
|
|
import math
|
|
import json
|
|
import tensorflow as tf
|
|
from six.moves import xrange
|
|
import logging
|
|
import argparse
|
|
import svgwrite
|
|
from tqdm import tqdm
|
|
import re
|
|
|
|
|
|
# import the sketch-rnn model, training and utility helpers from magenta
|
|
from magenta.models.sketch_rnn.sketch_rnn_train import *
|
|
from magenta.models.sketch_rnn.model import *
|
|
from magenta.models.sketch_rnn.utils import *
|
|
from magenta.models.sketch_rnn.rnn import *
|
|
|
|
|
|
# Named logger for this script; its threshold is lowered to DEBUG by --verbose.
logger = logging.getLogger('card')
# Install a root handler so records are emitted, with INFO as the default level.
logging.basicConfig(level=logging.INFO)
|
|
|
|
# Command-line interface. Every option documents what the rest of the script
# does with it, so `--help` is self-explanatory.
argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
    '--data_dir',
    type=str,
    default='./datasets/naam3',
    help='Dataset directory passed to load_dataset()',
)
argParser.add_argument(
    '--model_dir',
    type=str,
    default='./models/naam3',
    help='Directory holding model_config.json and the checkpoint to restore',
)
argParser.add_argument(
    '--output_dir',
    type=str,
    default='generated/naam3',
    help='Directory the generated SVG files are written to',
)
argParser.add_argument(
    '--width',
    type=str,
    default='90mm',
    help='Card width with unit, e.g. 90mm (the first integer in the value is used)',
)
argParser.add_argument(
    '--height',
    type=str,
    default='150mm',
    help='Card height with unit, e.g. 150mm (the first integer in the value is used)',
)
argParser.add_argument(
    '--rows',
    type=int,
    default=13,
    help='Number of rows; the sampling temperature increases with each row',
)
argParser.add_argument(
    '--columns',
    type=int,
    default=5,
    help='Number of sketches sampled per row',
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging',
)
args = argParser.parse_args()
|
|
|
|
# Presumably the typical sketch height in dataset coordinates (TODO confirm).
# The card layout divides it by a row's height to get a drawing scale factor.
dataset_baseheight = 10

# -v / --verbose: let the script's own logger emit DEBUG records.
if args.verbose:
    logger.setLevel(logging.DEBUG)
|
|
|
|
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25,
                  color="black", stroke_width=1):
    """Convert a stroke sequence into an SVG path element.

    Args:
        dwg: svgwrite drawing used to create the path element.
        strokes: rows indexed as [x offset, y offset, pen-lift flag, ...].
            Offsets are emitted with relative SVG commands, so each one is
            relative to the previous point.
        factor: scale divisor; every offset is divided by it.
        start_x, start_y: origin offset. The first point is placed at
            (start_x - min_x, start_y - min_y), where min_x/min_y come from
            get_bounds(strokes, factor).
        color: stroke colour of the path (default "black").
        stroke_width: stroke width of the path (default 1).

    Returns:
        The path element with no fill. It is not added to ``dwg``; the caller
        must add it.
    """
    min_x, _, min_y, _ = get_bounds(strokes, factor)
    # Absolute move to the start point. Every later point uses relative
    # commands, so only this point carries the bounding-box offset.
    parts = ["M%s,%s " % (start_x - min_x, start_y - min_y)]
    command = "m"
    lift_pen = 1  # the first point always begins a new sub-path
    for row in strokes:
        if lift_pen == 1:
            command = "m"  # pen was lifted: start a new sub-path
        elif command != "l":
            command = "l"  # first line segment after a move
        else:
            command = ""   # SVG implicitly repeats the previous 'l'
        x = float(row[0]) / factor
        y = float(row[1]) / factor
        lift_pen = row[2]
        parts.append(f"{command}{x:.5},{y:.5} ")
    # Join once instead of growing a string in the loop (quadratic on long sketches).
    return dwg.path("".join(parts)).stroke(color, stroke_width).fill("none")
|
|
|
|
def draw_strokes(strokes, factor=0.2, svg_filename='/tmp/sketch_rnn/svg/sample.svg'):
    """Render ``strokes`` into a standalone SVG file at ``svg_filename``.

    The parent directory of ``svg_filename`` is created first. The canvas is
    the scaled extent reported by get_bounds() plus 50 units in each
    dimension. strokesToPath() is called with its default start offset of
    25, so the sketch should sit inside a 25-unit margin (assuming get_bounds
    reports the sketch's extent). Nothing is displayed; the file is only saved.
    """
    target_dir = os.path.dirname(svg_filename)
    tf.gfile.MakeDirs(target_dir)
    left, right, top, bottom = get_bounds(strokes, factor)
    canvas_size = (50 + right - left, 50 + bottom - top)
    drawing = svgwrite.Drawing(svg_filename, size=canvas_size)
    drawing.add(strokesToPath(drawing, strokes, factor))
    drawing.save()
|
|
|
|
# Config keys whose values are compared against 1 and replaced by booleans
# before the config is handed to HParams.parse_json().
_BOOL_FLAGS = (
    'conditional',
    'is_training',
    'use_input_dropout',
    'use_output_dropout',
    'use_recurrent_dropout',
)


def _load_model_params(model_dir):
    """Read ``model_dir/model_config.json`` into default sketch-rnn hparams.

    Adapted from magenta's sketch_rnn_train.py to work with the deprecated
    tf.HParams API. Flags listed in _BOOL_FLAGS may be stored as 0/1, so they
    are converted to real booleans first. A flag missing from the config is
    left alone, and the default hparam value applies (previously a KeyError).
    """
    model_params = sketch_rnn_model.get_default_hparams()
    config_path = os.path.join(model_dir, 'model_config.json')
    with tf.gfile.Open(config_path, 'r') as f:
        data = json.load(f)
    for flag in _BOOL_FLAGS:
        if flag in data:
            data[flag] = (data[flag] == 1)
    model_params.parse_json(json.dumps(data))
    return model_params


def load_env_compatible(data_dir, model_dir):
    """Load the dataset and hparams for inference.

    Returns the result of ``load_dataset(data_dir, params, inference_mode=True)``.
    This script unpacks it as [train_set, valid_set, test_set, hps_model,
    eval_hps_model, sample_hps_model].
    """
    model_params = _load_model_params(model_dir)
    return load_dataset(data_dir, model_params, inference_mode=True)


def load_model_compatible(model_dir):
    """Build hparams for inference without loading a dataset.

    Returns:
        [model_params, eval_model_params, sample_model_params]. The eval params
        switch off dropout and training mode. The sample params additionally
        use max_seq_len = 1, so one point is sampled per step.
    """
    model_params = _load_model_params(model_dir)

    model_params.batch_size = 1  # only sample one at a time
    eval_model_params = sketch_rnn_model.copy_hparams(model_params)
    eval_model_params.use_input_dropout = 0
    eval_model_params.use_recurrent_dropout = 0
    eval_model_params.use_output_dropout = 0
    eval_model_params.is_training = 0
    sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
    sample_model_params.max_seq_len = 1  # sample one point at a time
    return [model_params, eval_model_params, sample_model_params]
|
|
|
|
|
|
# Module-level environment. encode() and decode() read sess, eval_model and
# sample_model, so this must run before they are called.
(train_set, valid_set, test_set,
 hps_model, eval_hps_model, sample_hps_model) = load_env_compatible(
    args.data_dir, args.model_dir)

# Build the three sketch-rnn models on a fresh graph. The eval and sample
# models pass reuse=True, presumably so they share the first model's variables.
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)

# Initialise all variables, then restore the trained weights from model_dir.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
load_checkpoint(sess, args.model_dir)
|
|
|
|
def encode(input_strokes):
    """Encode a sketch into its latent vector.

    Steps:
        1. Pad ``input_strokes`` with to_big_strokes() to the eval model's
           ``max_seq_len``. This is the same length decode() uses; the value
           was previously hard-coded as 614.
        2. Prepend the row [0, 0, 1, 0, 0] (presumably the model's
           start-of-sequence token).
        3. Run the eval model's ``batch_z``.

    As a side effect, a preview of the padded input is written through
    draw_strokes() to its default path.

    Returns:
        The ``batch_z`` output for the single batch entry.
    """
    max_len = eval_model.hps.max_seq_len
    strokes = to_big_strokes(input_strokes, max_len).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    # Former debug prints now go through the script logger (visible with -v).
    logger.debug('encode: seq_len=%s', seq_len)
    draw_strokes(to_normal_strokes(np.array(strokes)))
    logger.debug('encode: %d input rows (including start row)', len(strokes))
    return sess.run(
        eval_model.batch_z,
        feed_dict={
            eval_model.input_data: [strokes],
            eval_model.sequence_lengths: seq_len,
        },
    )[0]
|
|
|
|
def decode(z_input=None, temperature=0.1, factor=0.2):
    """Sample a sketch from the model and return it in normal stroke form.

    Args:
        z_input: optional latent vector to condition on. When None, the
            sampler receives z=None.
        temperature: sampling temperature forwarded to sample().
        factor: not used by this function; kept for interface compatibility.

    Returns:
        The sampled strokes after to_normal_strokes().
    """
    z_batch = None if z_input is None else [z_input]
    sampled, _ = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature,
        z=z_batch,
    )
    return to_normal_strokes(sampled)
|
|
|
|
def _leading_int(value, option):
    """Return the first run of digits in ``value`` (e.g. '90mm' -> 90) as an int.

    If ``value`` has no digits, exit via argParser.error() with a usage
    message instead of raising an opaque IndexError.
    """
    match = re.search(r'\d+', value)
    if match is None:
        argParser.error(f'{option} must contain a number, got {value!r}')
    return int(match.group())


# Card geometry: sizes such as '90mm' keep only their number, scaled by 10.
dims = (args.width, args.height)
width = _leading_int(args.width, '--width') * 10
height = _leading_int(args.height, '--height') * 10
item_count = args.rows * args.columns

# Divisor for strokesToPath() that maps a dataset_baseheight-tall sketch onto
# one row of the card. The per-sample SVG output below does not use it.
factor = dataset_baseheight / (height / args.rows)

# Write one SVG per sample. The temperature starts at 0.1 and rises by 1/rows
# per row; each column is an independent sample at that row's temperature.
with tqdm(total=item_count) as pbar:
    for row in range(args.rows):
        row_temp = .1 + row * (1. / args.rows)
        for column in range(args.columns):
            strokes = decode(temperature=row_temp)
            # row_temp is a float, so it needs 'f'; 'd' is integer-only and raised ValueError.
            fn = os.path.join(args.output_dir, f'generated-{row_temp:.2f}-{column}.svg')
            draw_strokes(strokes, svg_filename=fn)
            pbar.update()
|