work & thoughts in progress

This commit is contained in:
Ruben van de Ven 2019-08-25 17:19:27 +02:00
commit 35a398c42e
5 changed files with 2802 additions and 0 deletions

4
README.md Normal file
View file

@ -0,0 +1,4 @@
Successfull on naam-simple:
```
sketch_rnn_train --log_root=models/naam-simple --data_dir=datasets/naam-simple --hparams="data_set=[diede.npz],dec_model=layer_norm,dec_rnn_size=200,enc_model=layer_norm,enc_rnn_size=200,save_every=100,grad_clip=1.0,use_recurrent_dropout=0,conditional=False,num_steps=1000"
```

2227
Sketch_RNN.ipynb Normal file

File diff suppressed because one or more lines are too long

212
create_card.py Normal file
View file

@ -0,0 +1,212 @@
# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import logging
import argparse
import svgwrite
from tqdm import tqdm
import re
# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('card')
argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
'--data_dir',
type=str,
default='./datasets/naam3',
help=''
)
argParser.add_argument(
'--model_dir',
type=str,
default='./models/naam3',
)
argParser.add_argument(
'--output_dir',
type=str,
default='generated/naam3',
)
argParser.add_argument(
'--width',
type=str,
default='90mm',
)
argParser.add_argument(
'--height',
type=str,
default='150mm',
)
argParser.add_argument(
'--rows',
type=int,
default=13,
)
argParser.add_argument(
'--columns',
type=int,
default=5,
)
# argParser.add_argument(
# '--output_file',
# type=str,
# default='card.svg',
# )
argParser.add_argument(
'--verbose',
'-v',
action='store_true',
help='Debug logging'
)
args = argParser.parse_args()
dataset_baseheight = 10
if args.verbose:
logger.setLevel(logging.DEBUG)
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
lift_pen = 1
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
abs_x = start_x - min_x
abs_y = start_y - min_y
p = "M%s,%s " % (abs_x, abs_y)
# p = "M%s,%s " % (0, 0)
command = "m"
for i in xrange(len(strokes)):
if (lift_pen == 1):
command = "m"
elif (command != "l"):
command = "l"
else:
command = ""
x = float(strokes[i,0])/factor
y = float(strokes[i,1])/factor
lift_pen = strokes[i, 2]
p += f"{command}{x:.5},{y:.5} "
the_color = "black"
stroke_width = 1
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
# little function that displays vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
dims = (50 + max_x - min_x, 50 + max_y - min_y)
dwg = svgwrite.Drawing(svg_filename, size=dims)
dwg.add(strokesToPath(dwg, strokes, factor))
dwg.save()
def load_env_compatible(data_dir, model_dir):
"""Loads environment for inference mode, used in jupyter notebook."""
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
# to work with depreciated tf.HParams functionality
model_params = sketch_rnn_model.get_default_hparams()
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
data = json.load(f)
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
for fix in fix_list:
data[fix] = (data[fix] == 1)
model_params.parse_json(json.dumps(data))
return load_dataset(data_dir, model_params, inference_mode=True)
def load_model_compatible(model_dir):
"""Loads model for inference mode, used in jupyter notebook."""
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
# to work with depreciated tf.HParams functionality
model_params = sketch_rnn_model.get_default_hparams()
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
data = json.load(f)
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
for fix in fix_list:
data[fix] = (data[fix] == 1)
model_params.parse_json(json.dumps(data))
model_params.batch_size = 1 # only sample one at a time
eval_model_params = sketch_rnn_model.copy_hparams(model_params)
eval_model_params.use_input_dropout = 0
eval_model_params.use_recurrent_dropout = 0
eval_model_params.use_output_dropout = 0
eval_model_params.is_training = 0
sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
sample_model_params.max_seq_len = 1 # sample one point at a time
return [model_params, eval_model_params, sample_model_params]
# some basic initialisation (done before encode() and decode() as they use these variables)
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(args.data_dir, args.model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)
def encode(input_strokes):
"""
Encode input image into vector
"""
strokes = to_big_strokes(input_strokes, 614).tolist()
strokes.insert(0, [0, 0, 1, 0, 0])
seq_len = [len(input_strokes)]
print(seq_len)
draw_strokes(to_normal_strokes(np.array(strokes)))
print(np.array([strokes]).shape)
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
def decode(z_input=None, temperature=0.1, factor=0.2):
"""
Decode vector into strokes (image)
"""
z = None
if z_input is not None:
z = [z_input]
sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
strokes = to_normal_strokes(sample_strokes)
return strokes
dims = (args.width, args.height)
width = int(re.findall('\d+',args.width)[0])*10
height = int(re.findall('\d+',args.height)[0])*10
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
item_count = args.rows*args.columns
factor = dataset_baseheight/ (height/args.rows)
with tqdm(total=item_count) as pbar:
for row in range(args.rows):
start_y = row * (float(height) / args.rows)
row_temp = .1+row * (1./args.rows)
for column in range(args.columns):
strokes = decode(temperature=row_temp)
fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
draw_strokes(strokes, svg_filename=fn)
# current_nr = row * args.columns + column
# temp = .01+current_nr * (1./item_count)
# start_x = column * (float(width) / args.columns)
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
# dwg.add(path)
pbar.update()
# dwg.save()

150
create_dataset.py Normal file
View file

@ -0,0 +1,150 @@
from rdp import rdp
import xml.etree.ElementTree as ET
from svg.path import parse_path
import numpy as np
import logging
import argparse
import os
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')
argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
argParser.add_argument(
'--src',
type=str,
default='./data/naam',
help=''
)
argParser.add_argument(
'--dataset_dir',
type=str,
default='./datasets/naam',
)
argParser.add_argument(
'--multiply',
type=int,
default=2,
help="If you dont have enough items, automatically multiply it so there's at least 100 per set.",
)
argParser.add_argument(
'--verbose',
'-v',
action='store_true',
help='Debug logging'
)
args = argParser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
def getStroke3FromSvg(filename):
"""
Get drawing as stroke-3 format.
Gets each group as different drawing
points as as [dx, dy, pen state]
"""
logger.debug(f"Read {filename}")
s = ET.parse(filename)
root = s.getroot()
groups = root.findall("{http://www.w3.org/2000/svg}g")
sketches = []
for group in groups:
svg_paths = group.findall("{http://www.w3.org/2000/svg}path")
paths = []
min_x = None
min_y = None
max_y = None
for p in svg_paths:
path = parse_path(p.get("d"))
points = [[point.end.real, point.end.imag] for point in path]
x_points = np.array([p[0] for p in points])
y_points = np.array([p[1] for p in points])
if min_x is None:
min_x = min(x_points)
min_y = min(y_points)
max_y = max(y_points)
else:
min_x = min(min_x, min(x_points))
min_y = min(min_y, min(y_points))
max_y = max(max_y, max(y_points))
points = np.array([[x_points[i], y_points[i]] for i in range(len(points))])
paths.append(points)
# scale normalize & crop
scale = 512 / (max_y - min_y)
prev_x = None
prev_y = None
strokes = []
for path in paths:
path[:,0] -= min_x
path[:,1] -= min_y
path *= scale
#simplify using Ramer-Douglas-Peucker., see https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
if(len(path) > 800):
logger.debug(f'simplify {len(path)} factor 3.5')
path = rdp(path, epsilon=3.5)
logger.debug(f'\tnow {len(path)}')
if(len(path) > 490):
logger.debug(f'simplify {len(path)} factor 3')
path = rdp(path, epsilon=3)
logger.debug(f'\tnow {len(path)}')
if(len(path) > 300):
logger.debug(f'simplify {len(path)} factor 2')
path = rdp(path, epsilon=2.0)
logger.debug(f'\tnow {len(path)}')
for point in path:
if prev_x is not None and prev_y is not None:
strokes.append([int(point[0] - prev_x), int(point[1] - prev_y), 0])
prev_x = point[0]
prev_y = point[1]
# mark lifting of pen
strokes[-1][2] = 1
logger.debug(f"Paths: {len(strokes)}")
# strokes = np.array(strokes, dtype=np.int16)
sketches.append(strokes)
return sketches
def main():
sketches = {}
for dirName, subdirList, fileList in os.walk(args.src):
for fname in fileList:
filename = os.path.join(dirName, fname)
className = fname[:-4].rstrip('0123456789.- ')
if not className in sketches:
sketches[className] = []
sketches[className].extend(getStroke3FromSvg(filename))
for className in sketches:
filename = os.path.join(args.dataset_dir, className + '.npz')
itemsForClass = len(sketches[className])
if itemsForClass < 100:
logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
extras = []
for i in range(100 - itemsForClass):
extras.append(sketches[className][i % itemsForClass])
sketches[className].extend(extras)
logger.debug(f"Now {len(sketches[className])} samples for {className}")
sets = sketches[className]
# exit()
np.savez_compressed(filename, train=sets, valid=sets, test=sets)
logger.info(f"Saved {len(sets)} samples in {filename}")
if __name__ == '__main__':
main()

209
generate_svg.py Normal file
View file

@ -0,0 +1,209 @@
# import the required libraries
import numpy as np
import time
import random
import pickle
import codecs
import collections
import os
import math
import json
import tensorflow as tf
from six.moves import xrange
import argparse
import logging
from tqdm import tqdm
# libraries required for visualisation:
from IPython.display import SVG, display
import PIL
from PIL import Image
import matplotlib.pyplot as plt
# set numpy output to something sensible
np.set_printoptions(precision=8, edgeitems=6, linewidth=200, suppress=True)
import svgwrite # conda install -c omnia svgwrite=1.1.6
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')
argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
argParser.add_argument(
'--dataset_dir',
type=str,
default='./datasets/naam',
)
argParser.add_argument(
'--model_dir',
type=str,
default='./models/naam',
)
argParser.add_argument(
'--generated_dir',
type=str,
default='./generated/naam',
)
argParser.add_argument(
'--verbose',
'-v',
action='store_true',
help='Debug logging'
)
args = argParser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
data_dir = args.dataset_dir
model_dir = args.model_dir
tf.logging.info("TensorFlow Version: %s", tf.__version__)
# import our command line tools
from magenta.models.sketch_rnn.sketch_rnn_train import *
from magenta.models.sketch_rnn.model import *
from magenta.models.sketch_rnn.utils import *
from magenta.models.sketch_rnn.rnn import *
# little function that displays vector images and saves them to .svg
def draw_strokes(data, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
min_x, max_x, min_y, max_y = get_bounds(data, factor)
dims = (50 + max_x - min_x, 50 + max_y - min_y)
dwg = svgwrite.Drawing(svg_filename, size=dims)
# dwg.add(dwg.rect(insert=(0, 0), size=dims,fill='white'))
lift_pen = 1
abs_x = 25 - min_x
abs_y = 25 - min_y
p = "M%s,%s " % (abs_x, abs_y)
command = "m"
for i in xrange(len(data)):
if (lift_pen == 1):
command = "m"
elif (command != "l"):
command = "l"
else:
command = ""
x = float(data[i,0])/factor
y = float(data[i,1])/factor
lift_pen = data[i, 2]
p += command+str(x)+","+str(y)+" "
the_color = "black"
stroke_width = 1
dwg.add(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
dwg.save()
# display(SVG(dwg.tostring()))
# generate a 2D grid of many vector drawings
def make_grid_svg(s_list, grid_space=10.0, grid_space_x=16.0):
def get_start_and_end(x):
x = np.array(x)
x = x[:, 0:2]
x_start = x[0]
x_end = x.sum(axis=0)
x = x.cumsum(axis=0)
x_max = x.max(axis=0)
x_min = x.min(axis=0)
center_loc = (x_max+x_min)*0.5
return x_start-center_loc, x_end
x_pos = 0.0
y_pos = 0.0
result = [[x_pos, y_pos, 1]]
for sample in s_list:
s = sample[0]
grid_loc = sample[1]
grid_y = grid_loc[0]*grid_space+grid_space*0.5
grid_x = grid_loc[1]*grid_space_x+grid_space_x*0.5
start_loc, delta_pos = get_start_and_end(s)
loc_x = start_loc[0]
loc_y = start_loc[1]
new_x_pos = grid_x+loc_x
new_y_pos = grid_y+loc_y
result.append([new_x_pos-x_pos, new_y_pos-y_pos, 0])
result += s.tolist()
result[-1][2] = 1
x_pos = new_x_pos+delta_pos[0]
y_pos = new_y_pos+delta_pos[1]
return np.array(result)
def load_env_compatible(data_dir, model_dir):
"""Loads environment for inference mode, used in jupyter notebook."""
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
# to work with depreciated tf.HParams functionality
model_params = sketch_rnn_model.get_default_hparams()
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
data = json.load(f)
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
for fix in fix_list:
data[fix] = (data[fix] == 1)
model_params.parse_json(json.dumps(data))
return load_dataset(data_dir, model_params, inference_mode=True)
def load_model_compatible(model_dir):
"""Loads model for inference mode, used in jupyter notebook."""
# modified https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/sketch_rnn_train.py
# to work with depreciated tf.HParams functionality
model_params = sketch_rnn_model.get_default_hparams()
with tf.gfile.Open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
data = json.load(f)
fix_list = ['conditional', 'is_training', 'use_input_dropout', 'use_output_dropout', 'use_recurrent_dropout']
for fix in fix_list:
data[fix] = (data[fix] == 1)
model_params.parse_json(json.dumps(data))
model_params.batch_size = 1 # only sample one at a time
eval_model_params = sketch_rnn_model.copy_hparams(model_params)
eval_model_params.use_input_dropout = 0
eval_model_params.use_recurrent_dropout = 0
eval_model_params.use_output_dropout = 0
eval_model_params.is_training = 0
sample_model_params = sketch_rnn_model.copy_hparams(eval_model_params)
sample_model_params.max_seq_len = 1 # sample one point at a time
return [model_params, eval_model_params, sample_model_params]
[train_set, valid_set, test_set, hps_model, eval_hps_model, sample_hps_model] = load_env_compatible(data_dir, model_dir)
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)
# We define two convenience functions to encode a stroke into a latent vector, and decode from latent vector to stroke.
def encode(input_strokes):
strokes = to_big_strokes(input_strokes).tolist()
strokes.insert(0, [0, 0, 1, 0, 0])
seq_len = [len(input_strokes)]
draw_strokes(to_normal_strokes(np.array(strokes)))
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
def decode(z_input=None, draw_mode=True, temperature=0.1, factor=0.2, filename=None):
z = None
if z_input is not None:
z = [z_input]
sample_strokes, m = sample(sess, sample_model, seq_len=eval_model.hps.max_seq_len, temperature=temperature, z=z)
strokes = to_normal_strokes(sample_strokes)
if draw_mode:
draw_strokes(strokes, factor, svg_filename = filename)
return strokes
with tqdm(total=10*50) as pbar:
for i in range(10):
temperature = float(i+1) / 10.
for j in range(50):
filename = os.path.join(args.generated_dir, f"generated{temperature}-{j:03d}.svg")
_ = decode(temperature=temperature, filename=filename)
pbar.update()