Split paths over various groups and allow last item to be purely a sample

This commit is contained in:
Ruben van de Ven 2019-08-26 11:55:26 +02:00
parent 8ac94700cc
commit fa87808bb5

View file

@ -71,17 +71,43 @@ argParser.add_argument(
type=int,
default=5,
)
argParser.add_argument(
'--column_padding',
type=int,
default=10,
)
argParser.add_argument(
'--page_margin',
type=int,
default=70,
)
argParser.add_argument(
'--max_checkpoint_factor',
type=float,
default=1.,
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
)
# argParser.add_argument(
# '--output_file',
# type=str,
# default='card.svg',
# )
argParser.add_argument(
'--split_paths',
action='store_true',
help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
)
argParser.add_argument(
'--nr_of_paths',
type=int,
default=3,
help='If split_paths is given, this number defines the number of groups the paths should be split over'
)
argParser.add_argument(
'--last_is_target',
action='store_true',
help='Last item not generated, but as the given target'
)
argParser.add_argument(
'--last_in_group',
action='store_true',
help='If set, put the last rendition into a separate group'
)
argParser.add_argument(
'--verbose',
'-v',
@ -121,6 +147,55 @@ def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
stroke_width = 1
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
def splitStrokes(strokes):
"""
turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
"""
subStrokes = []
for stroke in strokes:
subStrokes.append(stroke)
if stroke[2] == 1 and subStrokes:
yield subStrokes
subStrokes = []
if subStrokes:
yield subStrokes
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
"""
Stroke not to a single path but each lift pen to a next path
"""
lift_pen = 1
min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
factor_x = max_width / (max_x-min_x)
factor_y = max_height / (max_y-min_y)
# assuming both < 1
factor = factor_y if factor_y < factor_x else factor_x
abs_x = start_x - min_x * factor
abs_y = start_y - min_y * factor
paths = []
for path in splitStrokes(strokes):
if len(path)<2:
logger.warning(f"Too few strokes to create path: {path}")
continue
for i in xrange(len(path)):
# i += 1 #first item is the move, which we already got above
x = float(path[i][0])*factor
y = float(path[i][1])*factor
abs_x += x
abs_y += y
if i == 0:
p = "M%s,%s l" % (abs_x, abs_y)
else:
p += f"{x:.5},{y:.5} "
the_color = "black"
stroke_width = 1
paths.append(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
return paths
# little function that displays vector images and saves them to .svg
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
@ -221,8 +296,12 @@ def loadCheckpoint(model_dir, nr):
dims = (args.width, args.height)
width = int(re.findall('\d+',args.width)[0])*10
height = int(re.findall('\d+',args.height)[0])*10
padding = 20
# padding = 20
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)]
for group in dwgGroups:
dwg.add(group)
checkpoints = getCheckpoints(args.model_dir)
item_count = args.rows*args.columns
@ -233,36 +312,40 @@ item_count = args.rows*args.columns
target_stroke = test_set.strokes[args.target_sample]
target_z = encode(target_stroke)
max_width = (width-(padding*(args.columns+1))) / args.columns
max_height = (height-(padding*(args.rows+1))) / args.rows
max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns
max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows
with tqdm(total=item_count) as pbar:
for row in range(args.rows):
#find the top left point for the strokes
min_y = row * (float(height-2*padding) / args.rows) + padding
min_y = row * (max_height + args.column_padding) + args.page_margin
for column in range(args.columns):
min_x = column * (float(width-2*padding) / args.columns) + padding
min_x = column * (max_width + args.column_padding) + args.page_margin
item = row*args.columns + column
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
checkpoint = checkpoints[checkpoint_idx]
loadCheckpoint(args.model_dir, checkpoint)
strokes = decode(target_z, temperature=1)
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
dwg.add(path)
# draw_strokes(strokes, svg_filename=fn)
# for row in range(args.rows):
# start_y = row * (float(height) / args.rows)
# row_temp = .1+row * (1./args.rows)
# for column in range(args.columns):
# strokes = decode(temperature=row_temp)
# fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
# draw_strokes(strokes, svg_filename=fn)
# current_nr = row * args.columns + column
# temp = .01+current_nr * (1./item_count)
# start_x = column * (float(width) / args.columns)
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
# dwg.add(path)
isLast = (row == args.rows-1 and column == args.columns-1)
if isLast and args.last_is_target:
strokes = target_stroke
else:
strokes = decode(target_z, temperature=1)
if args.last_in_group and isLast:
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
dwgGroups[-1].add(path)
elif args.split_paths:
paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
i = 0
for path in paths:
group = dwgGroups[i % args.nr_of_paths]
i+=1
group.add(path)
else:
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
dwgGroups[0].add(path)
pbar.update()
dwg.save()