Create simple card
This commit is contained in:
parent
35a398c42e
commit
8ac94700cc
2 changed files with 83 additions and 27 deletions
File diff suppressed because one or more lines are too long
104
create_card.py
104
create_card.py
|
@ -15,6 +15,8 @@ import argparse
|
|||
import svgwrite
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
import glob
|
||||
import math
|
||||
|
||||
|
||||
# import our command line tools
|
||||
|
@ -31,28 +33,33 @@ argParser = argparse.ArgumentParser(description='Create postcard')
|
|||
argParser.add_argument(
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='./datasets/naam3',
|
||||
default='./datasets/naam4',
|
||||
help=''
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default='./models/naam3',
|
||||
default='./models/naam4',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--output_dir',
|
||||
'--output_file',
|
||||
type=str,
|
||||
default='generated/naam3',
|
||||
default='generated/naam4.svg',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--target_sample',
|
||||
type=int,
|
||||
default=202,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--width',
|
||||
type=str,
|
||||
default='90mm',
|
||||
default='100mm',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--height',
|
||||
type=str,
|
||||
default='150mm',
|
||||
default='190mm',
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--rows',
|
||||
|
@ -64,6 +71,12 @@ argParser.add_argument(
|
|||
type=int,
|
||||
default=5,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--max_checkpoint_factor',
|
||||
type=float,
|
||||
default=1.,
|
||||
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
|
||||
)
|
||||
# argParser.add_argument(
|
||||
# '--output_file',
|
||||
# type=str,
|
||||
|
@ -77,16 +90,19 @@ argParser.add_argument(
|
|||
)
|
||||
args = argParser.parse_args()
|
||||
|
||||
dataset_baseheight = 10
|
||||
|
||||
if args.verbose:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
|
||||
def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
|
||||
lift_pen = 1
|
||||
min_x, max_x, min_y, max_y = get_bounds(strokes, factor)
|
||||
abs_x = start_x - min_x
|
||||
abs_y = start_y - min_y
|
||||
min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
|
||||
factor_x = max_width / (max_x-min_x)
|
||||
factor_y = max_height / (max_y-min_y)
|
||||
# assuming both < 1
|
||||
factor = factor_y if factor_y < factor_x else factor_x
|
||||
|
||||
abs_x = start_x - min_x * factor
|
||||
abs_y = start_y - min_y * factor
|
||||
p = "M%s,%s " % (abs_x, abs_y)
|
||||
# p = "M%s,%s " % (0, 0)
|
||||
command = "m"
|
||||
|
@ -97,8 +113,8 @@ def strokesToPath(dwg, strokes, factor=.2, start_x=25, start_y=25):
|
|||
command = "l"
|
||||
else:
|
||||
command = ""
|
||||
x = float(strokes[i,0])/factor
|
||||
y = float(strokes[i,1])/factor
|
||||
x = float(strokes[i,0])*factor
|
||||
y = float(strokes[i,1])*factor
|
||||
lift_pen = strokes[i, 2]
|
||||
p += f"{command}{x:.5},{y:.5} "
|
||||
the_color = "black"
|
||||
|
@ -160,7 +176,7 @@ sample_model = Model(sample_hps_model, reuse=True)
|
|||
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
# loads the weights from checkpoint into our model
|
||||
# loads the (latest) weights from checkpoint into our model
|
||||
load_checkpoint(sess, args.model_dir)
|
||||
|
||||
def encode(input_strokes):
|
||||
|
@ -171,7 +187,7 @@ def encode(input_strokes):
|
|||
strokes.insert(0, [0, 0, 1, 0, 0])
|
||||
seq_len = [len(input_strokes)]
|
||||
print(seq_len)
|
||||
draw_strokes(to_normal_strokes(np.array(strokes)))
|
||||
# draw_strokes(to_normal_strokes(np.array(strokes)))
|
||||
print(np.array([strokes]).shape)
|
||||
return sess.run(eval_model.batch_z, feed_dict={eval_model.input_data: [strokes], eval_model.sequence_lengths: seq_len})[0]
|
||||
|
||||
|
@ -186,27 +202,67 @@ def decode(z_input=None, temperature=0.1, factor=0.2):
|
|||
strokes = to_normal_strokes(sample_strokes)
|
||||
return strokes
|
||||
|
||||
def getCheckpoints(model_dir):
|
||||
files = glob.glob(os.path.join(model_dir,'vector-*.meta'))
|
||||
checkpoints = []
|
||||
for file in files:
|
||||
checkpoints.append(int(file.split('-')[-1][:-5]))
|
||||
return sorted(checkpoints)
|
||||
|
||||
currentCheckpoint = None
|
||||
def loadCheckpoint(model_dir, nr):
|
||||
if nr == currentCheckpoint:
|
||||
return
|
||||
# loads the intermediate weights from checkpoint into our model
|
||||
saver = tf.train.Saver(tf.global_variables())
|
||||
# tf.logging.info('Loading model %s.', ckpt.model_checkpoint_path)
|
||||
saver.restore(sess, os.path.join(model_dir, f"vector-{nr}"))
|
||||
|
||||
dims = (args.width, args.height)
|
||||
width = int(re.findall('\d+',args.width)[0])*10
|
||||
height = int(re.findall('\d+',args.height)[0])*10
|
||||
# dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||
padding = 20
|
||||
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||
|
||||
checkpoints = getCheckpoints(args.model_dir)
|
||||
item_count = args.rows*args.columns
|
||||
|
||||
factor = dataset_baseheight/ (height/args.rows)
|
||||
# factor = dataset_baseheight/ (height/args.rows)
|
||||
|
||||
# initialize z
|
||||
target_stroke = test_set.strokes[args.target_sample]
|
||||
target_z = encode(target_stroke)
|
||||
|
||||
max_width = (width-(padding*(args.columns+1))) / args.columns
|
||||
max_height = (height-(padding*(args.rows+1))) / args.rows
|
||||
|
||||
|
||||
with tqdm(total=item_count) as pbar:
|
||||
for row in range(args.rows):
|
||||
start_y = row * (float(height) / args.rows)
|
||||
row_temp = .1+row * (1./args.rows)
|
||||
#find the top left point for the strokes
|
||||
min_y = row * (float(height-2*padding) / args.rows) + padding
|
||||
for column in range(args.columns):
|
||||
strokes = decode(temperature=row_temp)
|
||||
fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
||||
draw_strokes(strokes, svg_filename=fn)
|
||||
min_x = column * (float(width-2*padding) / args.columns) + padding
|
||||
item = row*args.columns + column
|
||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
||||
checkpoint = checkpoints[checkpoint_idx]
|
||||
loadCheckpoint(args.model_dir, checkpoint)
|
||||
strokes = decode(target_z, temperature=1)
|
||||
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||
dwg.add(path)
|
||||
# draw_strokes(strokes, svg_filename=fn)
|
||||
# for row in range(args.rows):
|
||||
# start_y = row * (float(height) / args.rows)
|
||||
# row_temp = .1+row * (1./args.rows)
|
||||
# for column in range(args.columns):
|
||||
# strokes = decode(temperature=row_temp)
|
||||
# fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
||||
# draw_strokes(strokes, svg_filename=fn)
|
||||
|
||||
# current_nr = row * args.columns + column
|
||||
# temp = .01+current_nr * (1./item_count)
|
||||
# start_x = column * (float(width) / args.columns)
|
||||
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
||||
# dwg.add(path)
|
||||
pbar.update()
|
||||
# dwg.save()
|
||||
dwg.save()
|
||||
|
|
Loading…
Reference in a new issue