work & thoughts in progress
This commit is contained in:
commit
35a398c42e
5 changed files with 2802 additions and 0 deletions
4
README.md
Normal file
4
README.md
Normal file
|
@ -0,0 +1,4 @@
|
|||
Successfull on naam-simple:
|
||||
```
|
||||
sketch_rnn_train --log_root=models/naam-simple --data_dir=datasets/naam-simple --hparams="data_set=[diede.npz],dec_model=layer_norm,dec_rnn_size=200,enc_model=layer_norm,enc_rnn_size=200,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=False,num_steps=1000"
|
||||
```
|
2227
Sketch_RNN.ipynb
Normal file
2227
Sketch_RNN.ipynb
Normal file
File diff suppressed because one or more lines are too long
212
create_card.py
Normal file
212
create_card.py
Normal file
|
@ -0,0 +1,212 @@
|
|||
# import the required libraries
|
||||
import numpy as np
|
||||
import time
|
||||
import random
|
||||
import pickle
|
||||
import codecs
|
||||
import collections
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import tensorflow as tf
|
||||
from six.moves import xrange
|
||||
import logging
|
||||
import argparse
|
||||
import svgwrite
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
|
||||
|
||||
# import our command line tools
|
||||
from magenta.models.sketch_rnn.sketch_rnn_train import *
|
||||
from magenta.models.sketch_rnn.model import *
|
||||
from magenta.models.sketch_rnn.utils import *
|
||||
from magenta.models.sketch_rnn.rnn import *
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger('card')
|
||||
|
||||
argParser = argparse.ArgumentParser(description='Create postcard')
|
||||
argParser.add_argument(
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='./datasets/naam3',
|
||||
help=''
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default='./models/naam3',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--output_dir',
|
||||
type=str,
|
||||
default='generated/naam3',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--width',
|
||||
type=str,
|
||||
default='90mm',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--height',
|
||||
type=str,
|
||||
default='150mm',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--rows',
|
||||
type=int,
|
||||
default=13,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--columns',
|
||||
type=int,
|
||||
default=5,
|
||||
)
|
||||
# argParser.add_argument(
|
||||
# '--output_file',
|
||||
# type=str,
|
||||
# default='card.svg',
|
||||
# )
|
||||
argParser.add_argument(
|
||||
'--verbose',
|
||||
'-v',
|
||||
action='store_true',
|
||||
help='Debug logging'
|
||||
)
|
||||
args = argParser.parse_args()
|
||||
|
||||
dataset_baseheight = 10
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
|
||||
lift_pen = 1
|
||||
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
|
||||
abs_x = start_x - min_x
|
||||
abs_y = start_y - min_y
|
||||
p = "M%s,%s " % (abs_x, abs_y)
|
||||
# p = "M%s,%s " % (0, 0)
|
||||
command = "m"
|
||||
for i in xrange(len(strokes)):
|
||||
if (lift_pen == 1):
|
||||
command = "m"
|
||||
elif (command != "l"):
|
||||
command = "l"
|
||||
else:
|
||||
command = ""
|
||||
x = float(strokes[i,0])/factor
|
||||
y = float(strokes[i,1])/factor
|
||||
lift_pen = strokes[i, 2]
|
||||
p += f"{command}{x:.5},{y:.5} "
|
||||
the_color = "black"
|
||||
stroke_width = 1
|
||||
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
|
||||
|
||||
# little function that displays vector images and saves them to .svg
|
||||
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
|
||||
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
|
||||
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
|
||||
dims = (50 + max_x - min_x, 50 + max_y - min_y)
|
||||
dwg = svgwrite.Drawing(svg_filename, size=dims)
|
||||
dwg.add(strokesToPath(dwg, strokes, factor))
|
||||
dwg.save()
|
||||
|
||||
def load_env_compatible(data_dir, model_dir):
|
||||
"""Loads environment for inference mode, used in jupyter notebook."""
|
||||
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
|
||||
# to work with depreciated tf.HParams functionality
|
||||
model_params = sketch_rnn_model.get_default_hparams()
|
||||
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
|
||||
data = json.load(f)
|
||||
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
|
||||
for fix in fix_list:
|
||||
data[fix] = (data[fix] == 1)
|
||||
model_params.parse_json(json.dumps(data))
|
||||
return load_dataset(data_dir, model_params, inference_mode=True)
|
||||
|
||||
def load_model_compatible(model_dir):
|
||||
"""Loads model for inference mode, used in jupyter notebook."""
|
||||
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
|
||||
# to work with depreciated tf.HParams functionality
|
||||
model_params = sketch_rnn_model.get_default_hparams()
|
||||
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
|
||||
data = json.load(f)
|
||||
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
|
||||
for fix in fix_list:
|
||||
data[fix] = (data[fix] == 1)
|
||||
model_params.parse_json(json.dumps(data))
|
||||
|
||||
model_params.batch_size = 1 # only sample one at a time
|
||||
eval_model_params = sketch_rnn_model.copy_hparams(model_params)
|
||||
eval_model_params.use_input_dropout = 0
|
||||
eval_model_params.use_recurrent_dropout = 0
|
||||
eval_model_params.use_output_dropout = 0
|
||||
eval_model_params.is_training = 0
|
||||
sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
|
||||
sample_model_params.max_seq_len = 1 # sample one point at a time
|
||||
return [model_params, eval_model_params, sample_model_params]
|
||||
|
||||
|
||||
# some basic initialisation (done before encode() and decode() as they use these variables)
|
||||
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
|
||||
# construct the sketch-rnn model here:
|
||||
reset_graph()
|
||||
model = Model(hps_model)
|
||||
eval_model = Model(eval_hps_model, reuse=True)
|
||||
sample_model = Model(sample_hps_model, reuse=True)
|
||||
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
# loads the weights from checkpoint into our model
|
||||
load_checkpoint(sess, args.model_dir)
|
||||
|
||||
def encode(input_strokes):
|
||||
"""
|
||||
Encode input image into vector
|
||||
"""
|
||||
strokes = to_big_strokes(input_strokes, 614).tolist()
|
||||
strokes.insert(0, [0, 0, 1, 0, 0])
|
||||
seq_len = [len(input_strokes)]
|
||||
print(seq_len)
|
||||
draw_strokes(to_normal_strokes(np.array(strokes)))
|
||||
print(np.array([strokes]).shape)
|
||||
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
|
||||
|
||||
def decode(z_input=None, temperature=0.1, factor=0.2):
|
||||
"""
|
||||
Decode vector into strokes (image)
|
||||
"""
|
||||
z = None
|
||||
if z_input is not None:
|
||||
z = [z_input]
|
||||
sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
|
||||
strokes = to_normal_strokes(sample_strokes)
|
||||
return strokes
|
||||
|
||||
dims = (args.width, args.height)
|
||||
width = int(re.findall('\d+',args.width)[0])*10
|
||||
height = int(re.findall('\d+',args.height)[0])*10
|
||||
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||
|
||||
item_count = args.rows*args.columns
|
||||
|
||||
factor = dataset_baseheight/ (height/args.rows)
|
||||
|
||||
with tqdm(total=item_count) as pbar:
|
||||
for row in range(args.rows):
|
||||
start_y = row * (float(height) / args.rows)
|
||||
row_temp = .1+row * (1./args.rows)
|
||||
for column in range(args.columns):
|
||||
strokes = decode(temperature=row_temp)
|
||||
fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
||||
draw_strokes(strokes, svg_filename=fn)
|
||||
# current_nr = row * args.columns + column
|
||||
# temp = .01+current_nr * (1./item_count)
|
||||
# start_x = column * (float(width) / args.columns)
|
||||
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
||||
# dwg.add(path)
|
||||
pbar.update()
|
||||
# dwg.save()
|
150
create_dataset.py
Normal file
150
create_dataset.py
Normal file
|
@ -0,0 +1,150 @@
|
|||
from rdp import rdp
|
||||
import xml.etree.ElementTree as ET
|
||||
from svg.path import parse_path
|
||||
import numpy as np
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger('dataset')
|
||||
|
||||
argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
|
||||
argParser.add_argument(
|
||||
'--src',
|
||||
type=str,
|
||||
default='./data/naam',
|
||||
help=''
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--dataset_dir',
|
||||
type=str,
|
||||
default='./datasets/naam',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--multiply',
|
||||
type=int,
|
||||
default=2,
|
||||
help="If you dont have enough items, automatically multiply it so there's at least 100 per set.",
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--verbose',
|
||||
'-v',
|
||||
action='store_true',
|
||||
help='Debug logging'
|
||||
)
|
||||
args = argParser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def getStroke3FromSvg(filename):
|
||||
"""
|
||||
Get drawing as stroke-3 format.
|
||||
Gets each group as different drawing
|
||||
points as as [dx, dy, pen state]
|
||||
"""
|
||||
logger.debug(f"Read {filename}")
|
||||
s = ET.parse(filename)
|
||||
root = s.getroot()
|
||||
groups = root.findall("{http://www.w3.org/2000/svg}g")
|
||||
|
||||
sketches = []
|
||||
for group in groups:
|
||||
svg_paths = group.findall("{http://www.w3.org/2000/svg}path")
|
||||
paths = []
|
||||
|
||||
min_x = None
|
||||
min_y = None
|
||||
max_y = None
|
||||
|
||||
for p in svg_paths:
|
||||
path = parse_path(p.get("d"))
|
||||
|
||||
points = [[point.end.real, point.end.imag] for point in path]
|
||||
x_points = np.array([p[0] for p in points])
|
||||
y_points = np.array([p[1] for p in points])
|
||||
|
||||
if min_x is None:
|
||||
min_x = min(x_points)
|
||||
min_y = min(y_points)
|
||||
max_y = max(y_points)
|
||||
else:
|
||||
min_x = min(min_x, min(x_points))
|
||||
min_y = min(min_y, min(y_points))
|
||||
max_y = max(max_y, max(y_points))
|
||||
|
||||
points = np.array([[x_points[i], y_points[i]] for i in range(len(points))])
|
||||
paths.append(points)
|
||||
|
||||
# scale normalize & crop
|
||||
scale = 512 / (max_y - min_y)
|
||||
|
||||
prev_x = None
|
||||
prev_y = None
|
||||
|
||||
strokes = []
|
||||
for path in paths:
|
||||
path[:,0] -= min_x
|
||||
path[:,1] -= min_y
|
||||
path *= scale
|
||||
#simplify using Ramer-Douglas-Peucker., see https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
|
||||
if(len(path) > 800):
|
||||
logger.debug(f'simplify {len(path)} factor 3.5')
|
||||
path = rdp(path, epsilon=3.5)
|
||||
logger.debug(f'\tnow {len(path)}')
|
||||
if(len(path) > 490):
|
||||
logger.debug(f'simplify {len(path)} factor 3')
|
||||
path = rdp(path, epsilon=3)
|
||||
logger.debug(f'\tnow {len(path)}')
|
||||
if(len(path) > 300):
|
||||
logger.debug(f'simplify {len(path)} factor 2')
|
||||
path = rdp(path, epsilon=2.0)
|
||||
logger.debug(f'\tnow {len(path)}')
|
||||
for point in path:
|
||||
if prev_x is not None and prev_y is not None:
|
||||
strokes.append([int(point[0] - prev_x), int(point[1] - prev_y), 0])
|
||||
|
||||
prev_x = point[0]
|
||||
prev_y = point[1]
|
||||
|
||||
# mark lifting of pen
|
||||
strokes[-1][2] = 1
|
||||
|
||||
logger.debug(f"Paths: {len(strokes)}")
|
||||
# strokes = np.array(strokes, dtype=np.int16)
|
||||
sketches.append(strokes)
|
||||
|
||||
return sketches
|
||||
|
||||
def main():
|
||||
sketches = {}
|
||||
for dirName, subdirList, fileList in os.walk(args.src):
|
||||
for fname in fileList:
|
||||
filename = os.path.join(dirName, fname)
|
||||
className = fname[:-4].rstrip('0123456789.- ')
|
||||
if not className in sketches:
|
||||
sketches[className] = []
|
||||
sketches[className].extend(getStroke3FromSvg(filename))
|
||||
|
||||
|
||||
for className in sketches:
|
||||
filename = os.path.join(args.dataset_dir, className + '.npz')
|
||||
itemsForClass = len(sketches[className])
|
||||
if itemsForClass < 100:
|
||||
logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
|
||||
extras = []
|
||||
for i in range(100 - itemsForClass):
|
||||
extras.append(sketches[className][i % itemsForClass])
|
||||
sketches[className].extend(extras)
|
||||
logger.debug(f"Now {len(sketches[className])} samples for {className}")
|
||||
sets = sketches[className]
|
||||
# exit()
|
||||
np.savez_compressed(filename, train=sets, valid=sets, test=sets)
|
||||
logger.info(f"Saved {len(sets)} samples in {filename}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
209
generate_svg.py
Normal file
209
generate_svg.py
Normal file
|
@ -0,0 +1,209 @@
|
|||
# import the required libraries
|
||||
import numpy as np
|
||||
import time
|
||||
import random
|
||||
import pickle
|
||||
import codecs
|
||||
import collections
|
||||
import os
|
||||
import math
|
||||
import json
|
||||
import tensorflow as tf
|
||||
from six.moves import xrange
|
||||
import argparse
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
|
||||
# libraries required for visualisation:
|
||||
from IPython.display import SVG, display
|
||||
import PIL
|
||||
from PIL import Image
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# set numpy output to something sensible
|
||||
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
|
||||
|
||||
import svgwrite # conda install -c omnia svgwrite=1.1.6
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger('dataset')
|
||||
|
||||
argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
|
||||
argParser.add_argument(
|
||||
'--dataset_dir',
|
||||
type=str,
|
||||
default='./datasets/naam',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default='./models/naam',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--generated_dir',
|
||||
type=str,
|
||||
default='./generated/naam',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--verbose',
|
||||
'-v',
|
||||
action='store_true',
|
||||
help='Debug logging'
|
||||
)
|
||||
args = argParser.parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
data_dir = args.dataset_dir
|
||||
model_dir = args.model_dir
|
||||
|
||||
|
||||
tf.logging.info("TensorFlow Version: %s", tf.__version__)
|
||||
|
||||
# import our command line tools
|
||||
from magenta.models.sketch_rnn.sketch_rnn_train import *
|
||||
from magenta.models.sketch_rnn.model import *
|
||||
from magenta.models.sketch_rnn.utils import *
|
||||
from magenta.models.sketch_rnn.rnn import *
|
||||
|
||||
# little function that displays vector images and saves them to .svg
|
||||
def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
|
||||
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
|
||||
min_x, max_x, min_y, max_y = get_bounds(data, factor)
|
||||
dims = (50 + max_x - min_x, 50 + max_y - min_y)
|
||||
dwg = svgwrite.Drawing(svg_filename, size=dims)
|
||||
# dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
|
||||
lift_pen = 1
|
||||
abs_x = 25 - min_x
|
||||
abs_y = 25 - min_y
|
||||
p = "M%s,%s " % (abs_x, abs_y)
|
||||
command = "m"
|
||||
for i in xrange(len(data)):
|
||||
if (lift_pen == 1):
|
||||
command = "m"
|
||||
elif (command != "l"):
|
||||
command = "l"
|
||||
else:
|
||||
command = ""
|
||||
x = float(data[i,0])/factor
|
||||
y = float(data[i,1])/factor
|
||||
lift_pen = data[i, 2]
|
||||
p += command+str(x)+","+str(y)+" "
|
||||
the_color = "black"
|
||||
stroke_width = 1
|
||||
dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
|
||||
dwg.save()
|
||||
# display(SVG(dwg.tostring()))
|
||||
|
||||
# generate a 2D grid of many vector drawings
|
||||
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
|
||||
def get_start_and_end(x):
|
||||
x = np.array(x)
|
||||
x = x[:, 0:2]
|
||||
x_start = x[0]
|
||||
x_end = x.sum(axis=0)
|
||||
x = x.cumsum(axis=0)
|
||||
x_max = x.max(axis=0)
|
||||
x_min = x.min(axis=0)
|
||||
center_loc = (x_max+x_min)*0.5
|
||||
return x_start-center_loc, x_end
|
||||
x_pos = 0.0
|
||||
y_pos = 0.0
|
||||
result = [[x_pos, y_pos, 1]]
|
||||
for sample in s_list:
|
||||
s = sample[0]
|
||||
grid_loc = sample[1]
|
||||
grid_y = grid_loc[0]*grid_space+grid_space*0.5
|
||||
grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5
|
||||
start_loc, delta_pos = get_start_and_end(s)
|
||||
|
||||
loc_x = start_loc[0]
|
||||
loc_y = start_loc[1]
|
||||
new_x_pos = grid_x+loc_x
|
||||
new_y_pos = grid_y+loc_y
|
||||
result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])
|
||||
|
||||
result += s.tolist()
|
||||
result[-1][2] = 1
|
||||
x_pos = new_x_pos+delta_pos[0]
|
||||
y_pos = new_y_pos+delta_pos[1]
|
||||
return np.array(result)
|
||||
|
||||
def load_env_compatible(data_dir, model_dir):
|
||||
"""Loads environment for inference mode, used in jupyter notebook."""
|
||||
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
|
||||
# to work with depreciated tf.HParams functionality
|
||||
model_params = sketch_rnn_model.get_default_hparams()
|
||||
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
|
||||
data = json.load(f)
|
||||
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
|
||||
for fix in fix_list:
|
||||
data[fix] = (data[fix] == 1)
|
||||
model_params.parse_json(json.dumps(data))
|
||||
return load_dataset(data_dir, model_params, inference_mode=True)
|
||||
|
||||
def load_model_compatible(model_dir):
|
||||
"""Loads model for inference mode, used in jupyter notebook."""
|
||||
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
|
||||
# to work with depreciated tf.HParams functionality
|
||||
model_params = sketch_rnn_model.get_default_hparams()
|
||||
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
|
||||
data = json.load(f)
|
||||
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
|
||||
for fix in fix_list:
|
||||
data[fix] = (data[fix] == 1)
|
||||
model_params.parse_json(json.dumps(data))
|
||||
|
||||
model_params.batch_size = 1 # only sample one at a time
|
||||
eval_model_params = sketch_rnn_model.copy_hparams(model_params)
|
||||
eval_model_params.use_input_dropout = 0
|
||||
eval_model_params.use_recurrent_dropout = 0
|
||||
eval_model_params.use_output_dropout = 0
|
||||
eval_model_params.is_training = 0
|
||||
sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
|
||||
sample_model_params.max_seq_len = 1 # sample one point at a time
|
||||
return [model_params, eval_model_params, sample_model_params]
|
||||
|
||||
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)
|
||||
# construct the sketch-rnn model here:
|
||||
reset_graph()
|
||||
model = Model(hps_model)
|
||||
eval_model = Model(eval_hps_model, reuse=True)
|
||||
sample_model = Model(sample_hps_model, reuse=True)
|
||||
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
# loads the weights from checkpoint into our model
|
||||
load_checkpoint(sess, args.model_dir)
|
||||
|
||||
# We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke.
|
||||
def encode(input_strokes):
|
||||
strokes = to_big_strokes(input_strokes).tolist()
|
||||
strokes.insert(0, [0, 0, 1, 0, 0])
|
||||
seq_len = [len(input_strokes)]
|
||||
draw_strokes(to_normal_strokes(np.array(strokes)))
|
||||
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
|
||||
|
||||
def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2, filename=None):
|
||||
z = None
|
||||
if z_input is not None:
|
||||
z = [z_input]
|
||||
sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
|
||||
strokes = to_normal_strokes(sample_strokes)
|
||||
if draw_mode:
|
||||
draw_strokes(strokes, factor, svg_filename = filename)
|
||||
return strokes
|
||||
|
||||
with tqdm(total=10*50) as pbar:
|
||||
for i in range(10):
|
||||
temperature = float(i+1) / 10.
|
||||
for j in range(50):
|
||||
filename = os.path.join(args.generated_dir, f"generated{temperature}-{j:03d}.svg")
|
||||
_ = decode(temperature=temperature, filename=filename)
|
||||
pbar.update()
|
Loading…
Reference in a new issue