Create simple card
This commit is contained in:
parent
35a398c42e
commit
8ac94700cc
2 changed files with 83 additions and 27 deletions
File diff suppressed because one or more lines are too long
104
create_card.py
104
create_card.py
|
@ -15,6 +15,8 @@ import argparse
|
||||||
import svgwrite
|
import svgwrite
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import re
|
import re
|
||||||
|
import glob
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
# import our command line tools
|
# import our command line tools
|
||||||
|
@ -31,28 +33,33 @@ argParser = argparse.ArgumentParser(description='Create postcard')
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--data_dir',
|
'--data_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='./datasets/naam3',
|
default='./datasets/naam4',
|
||||||
help=''
|
help=''
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--model_dir',
|
'--model_dir',
|
||||||
type=str,
|
type=str,
|
||||||
default='./models/naam3',
|
default='./models/naam4',
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--output_dir',
|
'--output_file',
|
||||||
type=str,
|
type=str,
|
||||||
default='generated/naam3',
|
default='generated/naam4.svg',
|
||||||
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--target_sample',
|
||||||
|
type=int,
|
||||||
|
default=202,
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--width',
|
'--width',
|
||||||
type=str,
|
type=str,
|
||||||
default='90mm',
|
default='100mm',
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--height',
|
'--height',
|
||||||
type=str,
|
type=str,
|
||||||
default='150mm',
|
default='190mm',
|
||||||
)
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--rows',
|
'--rows',
|
||||||
|
@ -64,6 +71,12 @@ argParser.add_argument(
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=5,
|
||||||
)
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--max_checkpoint_factor',
|
||||||
|
type=float,
|
||||||
|
default=1.,
|
||||||
|
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
|
||||||
|
)
|
||||||
# argParser.add_argument(
|
# argParser.add_argument(
|
||||||
# '--output_file',
|
# '--output_file',
|
||||||
# type=str,
|
# type=str,
|
||||||
|
@ -77,16 +90,19 @@ argParser.add_argument(
|
||||||
)
|
)
|
||||||
args = argParser.parse_args()
|
args = argParser.parse_args()
|
||||||
|
|
||||||
dataset_baseheight = 10
|
|
||||||
|
|
||||||
if args.verbose:
|
if args.verbose:
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
|
def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
|
||||||
lift_pen = 1
|
lift_pen = 1
|
||||||
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
|
min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
|
||||||
abs_x = start_x - min_x
|
factor_x = max_width / (max_x-min_x)
|
||||||
abs_y = start_y - min_y
|
factor_y = max_height / (max_y-min_y)
|
||||||
|
# assuming both < 1
|
||||||
|
factor = factor_y if factor_y < factor_x else factor_x
|
||||||
|
|
||||||
|
abs_x = start_x - min_x * factor
|
||||||
|
abs_y = start_y - min_y * factor
|
||||||
p = "M%s,%s " % (abs_x, abs_y)
|
p = "M%s,%s " % (abs_x, abs_y)
|
||||||
# p = "M%s,%s " % (0, 0)
|
# p = "M%s,%s " % (0, 0)
|
||||||
command = "m"
|
command = "m"
|
||||||
|
@ -97,8 +113,8 @@ def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
|
||||||
command = "l"
|
command = "l"
|
||||||
else:
|
else:
|
||||||
command = ""
|
command = ""
|
||||||
x = float(strokes[i,0])/factor
|
x = float(strokes[i,0])*factor
|
||||||
y = float(strokes[i,1])/factor
|
y = float(strokes[i,1])*factor
|
||||||
lift_pen = strokes[i, 2]
|
lift_pen = strokes[i, 2]
|
||||||
p += f"{command}{x:.5},{y:.5} "
|
p += f"{command}{x:.5},{y:.5} "
|
||||||
the_color = "black"
|
the_color = "black"
|
||||||
|
@ -160,7 +176,7 @@ sample_model = Model(sample_hps_model, reuse=True)
|
||||||
|
|
||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer())
|
||||||
# loads the weights from checkpoint into our model
|
# loads the (latest) weights from checkpoint into our model
|
||||||
load_checkpoint(sess, args.model_dir)
|
load_checkpoint(sess, args.model_dir)
|
||||||
|
|
||||||
def encode(input_strokes):
|
def encode(input_strokes):
|
||||||
|
@ -171,7 +187,7 @@ def encode(input_strokes):
|
||||||
strokes.insert(0, [0, 0, 1, 0, 0])
|
strokes.insert(0, [0, 0, 1, 0, 0])
|
||||||
seq_len = [len(input_strokes)]
|
seq_len = [len(input_strokes)]
|
||||||
print(seq_len)
|
print(seq_len)
|
||||||
draw_strokes(to_normal_strokes(np.array(strokes)))
|
# draw_strokes(to_normal_strokes(np.array(strokes)))
|
||||||
print(np.array([strokes]).shape)
|
print(np.array([strokes]).shape)
|
||||||
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
|
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
|
||||||
|
|
||||||
|
@ -186,27 +202,67 @@ def decode(z_input=None, temperature=0.1, factor=0.2):
|
||||||
strokes = to_normal_strokes(sample_strokes)
|
strokes = to_normal_strokes(sample_strokes)
|
||||||
return strokes
|
return strokes
|
||||||
|
|
||||||
|
def getCheckpoints(model_dir):
|
||||||
|
files = glob.glob(os.path.join(model_dir,'vector-*.meta'))
|
||||||
|
checkpoints = []
|
||||||
|
for file in files:
|
||||||
|
checkpoints.append(int(file.split('-')[-1][:-5]))
|
||||||
|
return sorted(checkpoints)
|
||||||
|
|
||||||
|
currentCheckpoint = None
|
||||||
|
def loadCheckpoint(model_dir, nr):
|
||||||
|
if nr == currentCheckpoint:
|
||||||
|
return
|
||||||
|
# loads the intermediate weights from checkpoint into our model
|
||||||
|
saver = tf.train.Saver(tf.global_variables())
|
||||||
|
# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
|
||||||
|
saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
|
||||||
|
|
||||||
dims = (args.width, args.height)
|
dims = (args.width, args.height)
|
||||||
width = int(re.findall('\d+',args.width)[0])*10
|
width = int(re.findall('\d+',args.width)[0])*10
|
||||||
height = int(re.findall('\d+',args.height)[0])*10
|
height = int(re.findall('\d+',args.height)[0])*10
|
||||||
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
padding = 20
|
||||||
|
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||||
|
|
||||||
|
checkpoints = getCheckpoints(args.model_dir)
|
||||||
item_count = args.rows*args.columns
|
item_count = args.rows*args.columns
|
||||||
|
|
||||||
factor = dataset_baseheight/ (height/args.rows)
|
# factor = dataset_baseheight/ (height/args.rows)
|
||||||
|
|
||||||
|
# initialize z
|
||||||
|
target_stroke = test_set.strokes[args.target_sample]
|
||||||
|
target_z = encode(target_stroke)
|
||||||
|
|
||||||
|
max_width = (width-(padding*(args.columns+1))) / args.columns
|
||||||
|
max_height = (height-(padding*(args.rows+1))) / args.rows
|
||||||
|
|
||||||
|
|
||||||
with tqdm(total=item_count) as pbar:
|
with tqdm(total=item_count) as pbar:
|
||||||
for row in range(args.rows):
|
for row in range(args.rows):
|
||||||
start_y = row * (float(height) / args.rows)
|
#find the top left point for the strokes
|
||||||
row_temp = .1+row * (1./args.rows)
|
min_y = row * (float(height-2*padding) / args.rows) + padding
|
||||||
for column in range(args.columns):
|
for column in range(args.columns):
|
||||||
strokes = decode(temperature=row_temp)
|
min_x = column * (float(width-2*padding) / args.columns) + padding
|
||||||
fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
item = row*args.columns + column
|
||||||
draw_strokes(strokes, svg_filename=fn)
|
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
||||||
|
checkpoint = checkpoints[checkpoint_idx]
|
||||||
|
loadCheckpoint(args.model_dir, checkpoint)
|
||||||
|
strokes = decode(target_z, temperature=1)
|
||||||
|
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||||
|
dwg.add(path)
|
||||||
|
# draw_strokes(strokes, svg_filename=fn)
|
||||||
|
# for row in range(args.rows):
|
||||||
|
# start_y = row * (float(height) / args.rows)
|
||||||
|
# row_temp = .1+row * (1./args.rows)
|
||||||
|
# for column in range(args.columns):
|
||||||
|
# strokes = decode(temperature=row_temp)
|
||||||
|
# fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
||||||
|
# draw_strokes(strokes, svg_filename=fn)
|
||||||
|
|
||||||
# current_nr = row * args.columns + column
|
# current_nr = row * args.columns + column
|
||||||
# temp = .01+current_nr * (1./item_count)
|
# temp = .01+current_nr * (1./item_count)
|
||||||
# start_x = column * (float(width) / args.columns)
|
# start_x = column * (float(width) / args.columns)
|
||||||
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
||||||
# dwg.add(path)
|
# dwg.add(path)
|
||||||
pbar.update()
|
pbar.update()
|
||||||
# dwg.save()
|
dwg.save()
|
||||||
|
|
Loading…
Reference in a new issue