diff --git a/Sketch_RNN.ipynb b/Sketch_RNN.ipynb index dd9e1e9..b17cc9a 100644 --- a/Sketch_RNN.ipynb +++ b/Sketch_RNN.ipynb @@ -546,7 +546,7 @@ }, { "cell_type": "code", - "execution_count": 189, + "execution_count": 201, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -560,7 +560,7 @@ { "data": { "image/svg+xml": [ - "" + "" ], "text/plain": [ "" @@ -573,7 +573,7 @@ "source": [ "# get a sample drawing from the test set, and render it to .svg\n", "stroke = test_set.random_sample()\n", - "stroke=test_set.strokes[202]\n", + "stroke=test_set.strokes[252]\n", "draw_strokes(stroke)" ] }, diff --git a/create_card.py b/create_card.py index 6bb85dd..7e02c57 100644 --- a/create_card.py +++ b/create_card.py @@ -15,6 +15,8 @@ import argparse import svgwrite from tqdm import tqdm import re +import glob +import math # import our command line tools @@ -31,28 +33,33 @@ argParser = argparse.ArgumentParser(description='Create postcard') argParser.add_argument( '--data_dir', type=str, - default='./datasets/naam3', + default='./datasets/naam4', help='' ) argParser.add_argument( '--model_dir', type=str, - default='./models/naam3', + default='./models/naam4', ) argParser.add_argument( - '--output_dir', + '--output_file', type=str, - default='generated/naam3', + default='generated/naam4.svg', + ) +argParser.add_argument( + '--target_sample', + type=int, + default=202, ) argParser.add_argument( '--width', type=str, - default='90mm', + default='100mm', ) argParser.add_argument( '--height', type=str, - default='150mm', + default='190mm', ) argParser.add_argument( '--rows', @@ -64,6 +71,12 @@ argParser.add_argument( type=int, default=5, ) +argParser.add_argument( + '--max_checkpoint_factor', + type=float, + default=1., + help='If there are too many checkpoints that create smooth outcomes, limit that with this factor' + ) # argParser.add_argument( # '--output_file', # type=str, @@ -77,16 +90,19 @@ argParser.add_argument( ) args = argParser.parse_args() -dataset_baseheight = 10 - if args.verbose: logger.setLevel(logging.DEBUG) -def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25): +def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height): lift_pen = 1 - min_x, max_x, min_y, max_y = get_bounds(strokes, factor) - abs_x = start_x - min_x - abs_y = start_y - min_y + min_x, max_x, min_y, max_y = get_bounds(strokes, 1) + factor_x = max_width / (max_x-min_x) + factor_y = max_height / (max_y-min_y) + # assuming both < 1 + factor = factor_y if factor_y < factor_x else factor_x + + abs_x = start_x - min_x * factor + abs_y = start_y - min_y * factor p = "M%s,%s " % (abs_x, abs_y) # p = "M%s,%s " % (0, 0) command = "m" @@ -97,8 +113,8 @@ def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25): command = "l" else: command = "" - x = float(strokes[i,0])/factor - y = float(strokes[i,1])/factor + x = float(strokes[i,0])*factor + y = float(strokes[i,1])*factor lift_pen = strokes[i, 2] p += f"{command}{x:.5},{y:.5} " the_color = "black" @@ -160,7 +176,7 @@ sample_model = Model(sample_hps_model, reuse=True) sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) -# loads the weights from checkpoint into our model +# loads the (latest) weights from checkpoint into our model load_checkpoint(sess, args.model_dir) def encode(input_strokes): @@ -171,7 +187,7 @@ def encode(input_strokes): strokes.insert(0, [0, 0, 1, 0, 0]) seq_len = [len(input_strokes)] print(seq_len) - draw_strokes(to_normal_strokes(np.array(strokes))) + # draw_strokes(to_normal_strokes(np.array(strokes))) print(np.array([strokes]).shape) return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0] @@ -186,27 +202,67 @@ def decode(z_input=None, temperature=0.1, factor=0.2): strokes = to_normal_strokes(sample_strokes) return strokes +def getCheckpoints(model_dir): + files = glob.glob(os.path.join(model_dir,'vector-*.meta')) + checkpoints = [] + for file in files: + checkpoints.append(int(file.split('-')[-1][:-5])) + return sorted(checkpoints) + +currentCheckpoint = None +def loadCheckpoint(model_dir, nr): + if nr == currentCheckpoint: + return + # loads the intermediate weights from checkpoint into our model + saver = tf.train.Saver(tf.global_variables()) + # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path) + saver.restore(sess, os.path.join(model_dir, f"vector-{nr}")) + dims = (args.width, args.height) width = int(re.findall('\d+',args.width)[0])*10 height = int(re.findall('\d+',args.height)[0])*10 -# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}") +padding = 20 +dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}") +checkpoints = getCheckpoints(args.model_dir) item_count = args.rows*args.columns -factor = dataset_baseheight/ (height/args.rows) +# factor = dataset_baseheight/ (height/args.rows) + +# initialize z +target_stroke = test_set.strokes[args.target_sample] +target_z = encode(target_stroke) + +max_width = (width-(padding*(args.columns+1))) / args.columns +max_height = (height-(padding*(args.rows+1))) / args.rows + with tqdm(total=item_count) as pbar: for row in range(args.rows): - start_y = row * (float(height) / args.rows) - row_temp = .1+row * (1./args.rows) + #find the top left point for the strokes + min_y = row * (float(height-2*padding) / args.rows) + padding for column in range(args.columns): - strokes = decode(temperature=row_temp) - fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg') - draw_strokes(strokes, svg_filename=fn) + min_x = column * (float(width-2*padding) / args.columns) + padding + item = row*args.columns + column + checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints)) + checkpoint = checkpoints[checkpoint_idx] + loadCheckpoint(args.model_dir, checkpoint) + strokes = decode(target_z, temperature=1) + path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) + dwg.add(path) + # draw_strokes(strokes, svg_filename=fn) + # for row in range(args.rows): + # start_y = row * (float(height) / args.rows) + # row_temp = .1+row * (1./args.rows) + # for column in range(args.columns): + # strokes = decode(temperature=row_temp) + # fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg') + # draw_strokes(strokes, svg_filename=fn) + # current_nr = row * args.columns + column # temp = .01+current_nr * (1./item_count) # start_x = column * (float(width) / args.columns) # path = strokesToPath(dwg, strokes, factor, start_x, start_y) # dwg.add(path) pbar.update() -# dwg.save() +dwg.save()