Create simple card

Ruben van de Ven 2019-08-26 10:40:44 +02:00
parent 35a398c42e
commit 8ac94700cc
2 changed files with 83 additions and 27 deletions

File diff suppressed because one or more lines are too long

@@ -15,6 +15,8 @@ import argparse
 import svgwrite
 from tqdm import tqdm
 import re
+import glob
+import math
 
 # import our command line tools
@ -31,28 +33,33 @@ argParser = argparse.ArgumentParser(description='Create postcard')
argParser.add_argument( argParser.add_argument(
'--data_dir', '--data_dir',
type=str, type=str,
default='./datasets/naam3', default='./datasets/naam4',
help='' help=''
) )
argParser.add_argument( argParser.add_argument(
'--model_dir', '--model_dir',
type=str, type=str,
default='./models/naam3', default='./models/naam4',
) )
argParser.add_argument( argParser.add_argument(
'--output_dir', '--output_file',
type=str, type=str,
default='generated/naam3', default='generated/naam4.svg',
)
argParser.add_argument(
'--target_sample',
type=int,
default=202,
) )
argParser.add_argument( argParser.add_argument(
'--width', '--width',
type=str, type=str,
default='90mm', default='100mm',
) )
argParser.add_argument( argParser.add_argument(
'--height', '--height',
type=str, type=str,
default='150mm', default='190mm',
) )
argParser.add_argument( argParser.add_argument(
'--rows', '--rows',
@@ -64,6 +71,12 @@ argParser.add_argument(
     type=int,
     default=5,
 )
+argParser.add_argument(
+    '--max_checkpoint_factor',
+    type=float,
+    default=1.,
+    help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
+)
 # argParser.add_argument(
 #     '--output_file',
 #     type=str,
@@ -77,16 +90,19 @@ argParser.add_argument(
 )
 
 args = argParser.parse_args()
 
-dataset_baseheight = 10
-
 if args.verbose:
     logger.setLevel(logging.DEBUG)
 
-def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
+def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
     lift_pen = 1
-    min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
-    abs_x = start_x - min_x
-    abs_y = start_y - min_y
+    min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
+    factor_x = max_width / (max_x-min_x)
+    factor_y = max_height / (max_y-min_y)
+    # assuming both < 1
+    factor = factor_y if factor_y < factor_x else factor_x
+    abs_x = start_x - min_x * factor
+    abs_y = start_y - min_y * factor
     p = "M%s,%s " % (abs_x, abs_y)
     # p = "M%s,%s " % (0, 0)
     command = "m"
@@ -97,8 +113,8 @@ def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
             command = "l"
         else:
             command = ""
-        x = float(strokes[i,0])/factor
-        y = float(strokes[i,1])/factor
+        x = float(strokes[i,0])*factor
+        y = float(strokes[i,1])*factor
         lift_pen = strokes[i, 2]
         p += f"{command}{x:.5},{y:.5} "
     the_color = "black"
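The reworked strokesToPath no longer divides by a fixed factor but scales each drawing to fit its grid cell: it takes the stroke bounding box from get_bounds, computes the width and height ratios to the cell, and keeps the smaller one so the aspect ratio is preserved. A minimal standalone sketch of that scaling rule (the function name and numbers are illustrative, not from the repository):

def fit_factor(min_x, max_x, min_y, max_y, max_width, max_height):
    # uniform scale so the stroke bounding box fits inside the cell;
    # the smaller of the two ratios wins, preserving aspect ratio
    factor_x = max_width / (max_x - min_x)
    factor_y = max_height / (max_y - min_y)
    return min(factor_x, factor_y)

# e.g. a 200 x 50 bounding box in a 176 x 356 cell:
# min(176/200, 356/50) = min(0.88, 7.12) = 0.88
print(fit_factor(0, 200, 0, 50, 176, 356))  # 0.88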
@@ -160,7 +176,7 @@ sample_model = Model(sample_hps_model, reuse=True)
 sess = tf.InteractiveSession()
 sess.run(tf.global_variables_initializer())
 
-# loads the weights from checkpoint into our model
+# loads the (latest) weights from checkpoint into our model
 load_checkpoint(sess, args.model_dir)
 
 def encode(input_strokes):
@@ -171,7 +187,7 @@ def encode(input_strokes):
     strokes.insert(0, [0, 0, 1, 0, 0])
     seq_len = [len(input_strokes)]
     print(seq_len)
-    draw_strokes(to_normal_strokes(np.array(strokes)))
+    # draw_strokes(to_normal_strokes(np.array(strokes)))
     print(np.array([strokes]).shape)
     return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
@@ -186,27 +202,67 @@ def decode(z_input=None, temperature=0.1, factor=0.2):
     strokes = to_normal_strokes(sample_strokes)
     return strokes
 
+def getCheckpoints(model_dir):
+    files = glob.glob(os.path.join(model_dir,'vector-*.meta'))
+    checkpoints = []
+    for file in files:
+        checkpoints.append(int(file.split('-')[-1][:-5]))
+    return sorted(checkpoints)
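getCheckpoints derives the training step from each vector-*.meta file name by taking the part after the last '-' and dropping the five characters of '.meta'. A quick illustration with hypothetical file names (the paths are made up, only the parsing matches the commit):

files = ['models/naam4/vector-300.meta',
         'models/naam4/vector-1500.meta',
         'models/naam4/vector-900.meta']

# same parsing as getCheckpoints: strip '.meta', keep the step number, sort numerically
steps = sorted(int(f.split('-')[-1][:-5]) for f in files)
print(steps)  # [300, 900, 1500]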
+
+currentCheckpoint = None
+def loadCheckpoint(model_dir, nr):
+    if nr == currentCheckpoint:
+        return
+    # loads the intermediate weights from checkpoint into our model
+    saver = tf.train.Saver(tf.global_variables())
+    # tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
+    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
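Note that, as committed, loadCheckpoint never assigns currentCheckpoint, so the early-return guard can never skip a redundant restore. A sketch of a variant that records the loaded checkpoint (an assumption about the intent, not what this commit does):

currentCheckpoint = None

def loadCheckpoint(model_dir, nr):
    global currentCheckpoint
    if nr == currentCheckpoint:
        return  # this checkpoint is already loaded, skip the restore
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
    currentCheckpoint = nr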
+
 dims = (args.width, args.height)
 width = int(re.findall('\d+',args.width)[0])*10
 height = int(re.findall('\d+',args.height)[0])*10
-# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
+padding = 20
+dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
+
+checkpoints = getCheckpoints(args.model_dir)
 
 item_count = args.rows*args.columns
-factor = dataset_baseheight/ (height/args.rows)
+# factor = dataset_baseheight/ (height/args.rows)
+
+# initialize z
+target_stroke = test_set.strokes[args.target_sample]
+target_z = encode(target_stroke)
+
+max_width = (width-(padding*(args.columns+1))) / args.columns
+max_height = (height-(padding*(args.rows+1))) / args.rows
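With the new defaults the viewBox is 1000 x 1900 units ('100mm' and '190mm' scaled by 10), so for an assumed 5 x 5 grid the padded cell size works out as below; a quick check of the arithmetic (grid size is an assumption, only --rows is shown with a default of 5):

width, height, padding = 1000, 1900, 20
rows = columns = 5  # assumed grid size

max_width = (width - padding * (columns + 1)) / columns    # (1000 - 120) / 5 = 176.0
max_height = (height - padding * (rows + 1)) / rows        # (1900 - 120) / 5 = 356.0
print(max_width, max_height)  # 176.0 356.0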
 with tqdm(total=item_count) as pbar:
     for row in range(args.rows):
-        start_y = row * (float(height) / args.rows)
-        row_temp = .1+row * (1./args.rows)
+        #find the top left point for the strokes
+        min_y = row * (float(height-2*padding) / args.rows) + padding
         for column in range(args.columns):
-            strokes = decode(temperature=row_temp)
-            fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
-            draw_strokes(strokes, svg_filename=fn)
+            min_x = column * (float(width-2*padding) / args.columns) + padding
+            item = row*args.columns + column
+            checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
+            checkpoint = checkpoints[checkpoint_idx]
+            loadCheckpoint(args.model_dir, checkpoint)
+            strokes = decode(target_z, temperature=1)
+            path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
+            dwg.add(path)
+            # draw_strokes(strokes, svg_filename=fn)
+            # for row in range(args.rows):
+            #     start_y = row * (float(height) / args.rows)
+            #     row_temp = .1+row * (1./args.rows)
+            #     for column in range(args.columns):
+            #         strokes = decode(temperature=row_temp)
+            #         fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
+            #         draw_strokes(strokes, svg_filename=fn)
             # current_nr = row * args.columns + column
             # temp = .01+current_nr * (1./item_count)
             # start_x = column * (float(width) / args.columns)
             # path = strokesToPath(dwg, strokes, factor, start_x, start_y)
             # dwg.add(path)
             pbar.update()
 
-# dwg.save()
+dwg.save()
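The checkpoint_idx expression walks the grid cells through the sorted checkpoints in training order, and --max_checkpoint_factor caps how far into the list the last cell reaches. A worked example of that mapping, assuming a 5 x 5 grid and 50 saved checkpoints (both numbers are illustrative):

import math

item_count = 25                          # assumed 5 rows x 5 columns
checkpoints = list(range(0, 5000, 100))  # stand-in for 50 sorted checkpoint steps

def checkpoint_for(item, max_checkpoint_factor=1.0):
    # same mapping as the commit: cell index -> index into the sorted checkpoints
    idx = math.floor(float(item) * max_checkpoint_factor / item_count * len(checkpoints))
    return checkpoints[idx]

print(checkpoint_for(0))        # 0    -> earliest checkpoint, roughest strokes
print(checkpoint_for(24))       # 4800 -> near the latest checkpoint
print(checkpoint_for(24, 0.5))  # 2400 -> factor 0.5 stops halfway through training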