diff --git a/Sketch_RNN.ipynb b/Sketch_RNN.ipynb
index dd9e1e9..b17cc9a 100644
--- a/Sketch_RNN.ipynb
+++ b/Sketch_RNN.ipynb
@@ -546,7 +546,7 @@
},
{
"cell_type": "code",
- "execution_count": 189,
+ "execution_count": 201,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@@ -560,7 +560,7 @@
{
"data": {
"image/svg+xml": [
- ""
+ ""
],
"text/plain": [
""
@@ -573,7 +573,7 @@
"source": [
"# get a sample drawing from the test set, and render it to .svg\n",
"stroke = test_set.random_sample()\n",
- "stroke=test_set.strokes[202]\n",
+ "stroke=test_set.strokes[252]\n",
"draw_strokes(stroke)"
]
},
diff --git a/create_card.py b/create_card.py
index 6bb85dd..7e02c57 100644
--- a/create_card.py
+++ b/create_card.py
@@ -15,6 +15,8 @@ import argparse
import svgwrite
from tqdm import tqdm
import re
+import glob
+import math
# import our command line tools
@@ -31,28 +33,33 @@ argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument(
'--data_dir',
type=str,
- default='./datasets/naam3',
+ default='./datasets/naam4',
help=''
)
argParser.add_argument(
'--model_dir',
type=str,
- default='./models/naam3',
+ default='./models/naam4',
)
argParser.add_argument(
- '--output_dir',
+ '--output_file',
type=str,
- default='generated/naam3',
+ default='generated/naam4.svg',
+ )
+argParser.add_argument(
+ '--target_sample',
+ type=int,
+ default=202,
)
argParser.add_argument(
'--width',
type=str,
- default='90mm',
+ default='100mm',
)
argParser.add_argument(
'--height',
type=str,
- default='150mm',
+ default='190mm',
)
argParser.add_argument(
'--rows',
@@ -64,6 +71,12 @@ argParser.add_argument(
type=int,
default=5,
)
+argParser.add_argument(
+ '--max_checkpoint_factor',
+ type=float,
+ default=1.,
+ help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
+ )
# argParser.add_argument(
# '--output_file',
# type=str,
@@ -77,16 +90,19 @@ argParser.add_argument(
)
args = argParser.parse_args()
-dataset_baseheight = 10
-
if args.verbose:
logger.setLevel(logging.DEBUG)
-def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
+def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
lift_pen = 1
- min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
- abs_x = start_x - min_x
- abs_y = start_y - min_y
+ min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
+ factor_x = max_width / (max_x-min_x)
+ factor_y = max_height / (max_y-min_y)
+    # use the smaller scale factor so the drawing fits inside both max_width and max_height
+ factor = factor_y if factor_y < factor_x else factor_x
+
+ abs_x = start_x - min_x * factor
+ abs_y = start_y - min_y * factor
p = "M%s,%s " % (abs_x, abs_y)
# p = "M%s,%s " % (0, 0)
command = "m"
@@ -97,8 +113,8 @@ def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
command = "l"
else:
command = ""
- x = float(strokes[i,0])/factor
- y = float(strokes[i,1])/factor
+ x = float(strokes[i,0])*factor
+ y = float(strokes[i,1])*factor
lift_pen = strokes[i, 2]
p += f"{command}{x:.5},{y:.5} "
the_color = "black"
@@ -160,7 +176,7 @@ sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
-# loads the weights from checkpoint into our model
+# loads the (latest) weights from checkpoint into our model
load_checkpoint(sess, args.model_dir)
def encode(input_strokes):
@@ -171,7 +187,7 @@ def encode(input_strokes):
strokes.insert(0, [0, 0, 1, 0, 0])
seq_len = [len(input_strokes)]
print(seq_len)
- draw_strokes(to_normal_strokes(np.array(strokes)))
+ # draw_strokes(to_normal_strokes(np.array(strokes)))
print(np.array([strokes]).shape)
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
@@ -186,27 +202,67 @@ def decode(z_input=None, temperature=0.1, factor=0.2):
strokes = to_normal_strokes(sample_strokes)
return strokes
+def getCheckpoints(model_dir):
+ files = glob.glob(os.path.join(model_dir,'vector-*.meta'))
+ checkpoints = []
+ for file in files:
+ checkpoints.append(int(file.split('-')[-1][:-5]))
+ return sorted(checkpoints)
+currentCheckpoint = None
+def loadCheckpoint(model_dir, nr):
+    global currentCheckpoint
+    if nr == currentCheckpoint:
+        return
+    # loads the intermediate weights from checkpoint into our model
+    saver = tf.train.Saver(tf.global_variables())
+    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
+    currentCheckpoint = nr
+
dims = (args.width, args.height)
width = int(re.findall('\d+',args.width)[0])*10
height = int(re.findall('\d+',args.height)[0])*10
-# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
+padding = 20
+dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
+checkpoints = getCheckpoints(args.model_dir)
item_count = args.rows*args.columns
-factor = dataset_baseheight/ (height/args.rows)
+# factor = dataset_baseheight/ (height/args.rows)
+
+# initialize z
+target_stroke = test_set.strokes[args.target_sample]
+target_z = encode(target_stroke)
+
+max_width = (width-(padding*(args.columns+1))) / args.columns
+max_height = (height-(padding*(args.rows+1))) / args.rows
+
with tqdm(total=item_count) as pbar:
for row in range(args.rows):
- start_y = row * (float(height) / args.rows)
- row_temp = .1+row * (1./args.rows)
+    # find the top edge (y origin) of this row of grid cells
+ min_y = row * (float(height-2*padding) / args.rows) + padding
for column in range(args.columns):
- strokes = decode(temperature=row_temp)
- fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
- draw_strokes(strokes, svg_filename=fn)
+ min_x = column * (float(width-2*padding) / args.columns) + padding
+ item = row*args.columns + column
+ checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
+ checkpoint = checkpoints[checkpoint_idx]
+ loadCheckpoint(args.model_dir, checkpoint)
+ strokes = decode(target_z, temperature=1)
+ path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
+ dwg.add(path)
+ # draw_strokes(strokes, svg_filename=fn)
+ # for row in range(args.rows):
+ # start_y = row * (float(height) / args.rows)
+ # row_temp = .1+row * (1./args.rows)
+ # for column in range(args.columns):
+ # strokes = decode(temperature=row_temp)
+ # fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
+ # draw_strokes(strokes, svg_filename=fn)
+
# current_nr = row * args.columns + column
# temp = .01+current_nr * (1./item_count)
# start_x = column * (float(width) / args.columns)
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
# dwg.add(path)
pbar.update()
-# dwg.save()
+dwg.save()