Split paths over various groups and allow last item to be purely a sample
This commit is contained in:
parent
8ac94700cc
commit
fa87808bb5
1 changed files with 110 additions and 27 deletions
137
create_card.py
137
create_card.py
|
@ -71,17 +71,43 @@ argParser.add_argument(
|
|||
type=int,
|
||||
default=5,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--column_padding',
|
||||
type=int,
|
||||
default=10,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--page_margin',
|
||||
type=int,
|
||||
default=70,
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--max_checkpoint_factor',
|
||||
type=float,
|
||||
default=1.,
|
||||
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
|
||||
)
|
||||
# argParser.add_argument(
|
||||
# '--output_file',
|
||||
# type=str,
|
||||
# default='card.svg',
|
||||
# )
|
||||
argParser.add_argument(
|
||||
'--split_paths',
|
||||
action='store_true',
|
||||
help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--nr_of_paths',
|
||||
type=int,
|
||||
default=3,
|
||||
help='If split_paths is given, this number defines the number of groups the paths should be split over'
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--last_is_target',
|
||||
action='store_true',
|
||||
help='Last item not generated, but as the given target'
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--last_in_group',
|
||||
action='store_true',
|
||||
help='If set, put the last rendition into a separate group'
|
||||
)
|
||||
argParser.add_argument(
|
||||
'--verbose',
|
||||
'-v',
|
||||
|
@ -121,6 +147,55 @@ def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
|
|||
stroke_width = 1
|
||||
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
|
||||
|
||||
def splitStrokes(strokes):
|
||||
"""
|
||||
turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
|
||||
"""
|
||||
subStrokes = []
|
||||
for stroke in strokes:
|
||||
subStrokes.append(stroke)
|
||||
if stroke[2] == 1 and subStrokes:
|
||||
yield subStrokes
|
||||
subStrokes = []
|
||||
if subStrokes:
|
||||
yield subStrokes
|
||||
|
||||
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
|
||||
"""
|
||||
Stroke not to a single path but each lift pen to a next path
|
||||
"""
|
||||
|
||||
lift_pen = 1
|
||||
min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
|
||||
factor_x = max_width / (max_x-min_x)
|
||||
factor_y = max_height / (max_y-min_y)
|
||||
# assuming both < 1
|
||||
factor = factor_y if factor_y < factor_x else factor_x
|
||||
|
||||
abs_x = start_x - min_x * factor
|
||||
abs_y = start_y - min_y * factor
|
||||
paths = []
|
||||
for path in splitStrokes(strokes):
|
||||
if len(path)<2:
|
||||
logger.warning(f"Too few strokes to create path: {path}")
|
||||
continue
|
||||
for i in xrange(len(path)):
|
||||
# i += 1 #first item is the move, which we already got above
|
||||
x = float(path[i][0])*factor
|
||||
y = float(path[i][1])*factor
|
||||
abs_x += x
|
||||
abs_y += y
|
||||
|
||||
if i == 0:
|
||||
p = "M%s,%s l" % (abs_x, abs_y)
|
||||
else:
|
||||
p += f"{x:.5},{y:.5} "
|
||||
the_color = "black"
|
||||
stroke_width = 1
|
||||
paths.append(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
|
||||
return paths
|
||||
|
||||
|
||||
# little function that displays vector images and saves them to .svg
|
||||
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
|
||||
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
|
||||
|
@ -221,8 +296,12 @@ def loadCheckpoint(model_dir, nr):
|
|||
dims = (args.width, args.height)
|
||||
width = int(re.findall('\d+',args.width)[0])*10
|
||||
height = int(re.findall('\d+',args.height)[0])*10
|
||||
padding = 20
|
||||
# padding = 20
|
||||
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
|
||||
dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)]
|
||||
for group in dwgGroups:
|
||||
dwg.add(group)
|
||||
|
||||
checkpoints = getCheckpoints(args.model_dir)
|
||||
item_count = args.rows*args.columns
|
||||
|
@ -233,36 +312,40 @@ item_count = args.rows*args.columns
|
|||
target_stroke = test_set.strokes[args.target_sample]
|
||||
target_z = encode(target_stroke)
|
||||
|
||||
max_width = (width-(padding*(args.columns+1))) / args.columns
|
||||
max_height = (height-(padding*(args.rows+1))) / args.rows
|
||||
|
||||
max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns
|
||||
max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows
|
||||
|
||||
with tqdm(total=item_count) as pbar:
|
||||
for row in range(args.rows):
|
||||
#find the top left point for the strokes
|
||||
min_y = row * (float(height-2*padding) / args.rows) + padding
|
||||
min_y = row * (max_height + args.column_padding) + args.page_margin
|
||||
for column in range(args.columns):
|
||||
min_x = column * (float(width-2*padding) / args.columns) + padding
|
||||
min_x = column * (max_width + args.column_padding) + args.page_margin
|
||||
item = row*args.columns + column
|
||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
||||
checkpoint = checkpoints[checkpoint_idx]
|
||||
loadCheckpoint(args.model_dir, checkpoint)
|
||||
strokes = decode(target_z, temperature=1)
|
||||
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||
dwg.add(path)
|
||||
# draw_strokes(strokes, svg_filename=fn)
|
||||
# for row in range(args.rows):
|
||||
# start_y = row * (float(height) / args.rows)
|
||||
# row_temp = .1+row * (1./args.rows)
|
||||
# for column in range(args.columns):
|
||||
# strokes = decode(temperature=row_temp)
|
||||
# fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
||||
# draw_strokes(strokes, svg_filename=fn)
|
||||
|
||||
# current_nr = row * args.columns + column
|
||||
# temp = .01+current_nr * (1./item_count)
|
||||
# start_x = column * (float(width) / args.columns)
|
||||
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
||||
# dwg.add(path)
|
||||
isLast = (row == args.rows-1 and column == args.columns-1)
|
||||
|
||||
if isLast and args.last_is_target:
|
||||
strokes = target_stroke
|
||||
else:
|
||||
strokes = decode(target_z, temperature=1)
|
||||
|
||||
if args.last_in_group and isLast:
|
||||
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||
dwgGroups[-1].add(path)
|
||||
elif args.split_paths:
|
||||
paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||
i = 0
|
||||
for path in paths:
|
||||
group = dwgGroups[i % args.nr_of_paths]
|
||||
i+=1
|
||||
group.add(path)
|
||||
else:
|
||||
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||
dwgGroups[0].add(path)
|
||||
|
||||
pbar.update()
|
||||
dwg.save()
|
||||
|
|
Loading…
Reference in a new issue