Split paths over various groups and allow last item to be purely a sample
This commit is contained in:
parent
8ac94700cc
commit
fa87808bb5
1 changed files with 110 additions and 27 deletions
137
create_card.py
137
create_card.py
|
@ -71,17 +71,43 @@ argParser.add_argument(
|
||||||
type=int,
|
type=int,
|
||||||
default=5,
|
default=5,
|
||||||
)
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--column_padding',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--page_margin',
|
||||||
|
type=int,
|
||||||
|
default=70,
|
||||||
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--max_checkpoint_factor',
|
'--max_checkpoint_factor',
|
||||||
type=float,
|
type=float,
|
||||||
default=1.,
|
default=1.,
|
||||||
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
|
help='If there are too many checkpoints that create smooth outcomes, limit that with this factor'
|
||||||
)
|
)
|
||||||
# argParser.add_argument(
|
argParser.add_argument(
|
||||||
# '--output_file',
|
'--split_paths',
|
||||||
# type=str,
|
action='store_true',
|
||||||
# default='card.svg',
|
help='If a stroke contains multiple steps/paths, split these into separate paths/strokes'
|
||||||
# )
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--nr_of_paths',
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help='If split_paths is given, this number defines the number of groups the paths should be split over'
|
||||||
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--last_is_target',
|
||||||
|
action='store_true',
|
||||||
|
help='Last item not generated, but as the given target'
|
||||||
|
)
|
||||||
|
argParser.add_argument(
|
||||||
|
'--last_in_group',
|
||||||
|
action='store_true',
|
||||||
|
help='If set, put the last rendition into a separate group'
|
||||||
|
)
|
||||||
argParser.add_argument(
|
argParser.add_argument(
|
||||||
'--verbose',
|
'--verbose',
|
||||||
'-v',
|
'-v',
|
||||||
|
@ -121,6 +147,55 @@ def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height):
|
||||||
stroke_width = 1
|
stroke_width = 1
|
||||||
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
|
return dwg.path(p).stroke(the_color,stroke_width).fill("none")
|
||||||
|
|
||||||
|
def splitStrokes(strokes):
|
||||||
|
"""
|
||||||
|
turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]]
|
||||||
|
"""
|
||||||
|
subStrokes = []
|
||||||
|
for stroke in strokes:
|
||||||
|
subStrokes.append(stroke)
|
||||||
|
if stroke[2] == 1 and subStrokes:
|
||||||
|
yield subStrokes
|
||||||
|
subStrokes = []
|
||||||
|
if subStrokes:
|
||||||
|
yield subStrokes
|
||||||
|
|
||||||
|
def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height):
|
||||||
|
"""
|
||||||
|
Stroke not to a single path but each lift pen to a next path
|
||||||
|
"""
|
||||||
|
|
||||||
|
lift_pen = 1
|
||||||
|
min_x, max_x, min_y, max_y = get_bounds(strokes, 1)
|
||||||
|
factor_x = max_width / (max_x-min_x)
|
||||||
|
factor_y = max_height / (max_y-min_y)
|
||||||
|
# assuming both < 1
|
||||||
|
factor = factor_y if factor_y < factor_x else factor_x
|
||||||
|
|
||||||
|
abs_x = start_x - min_x * factor
|
||||||
|
abs_y = start_y - min_y * factor
|
||||||
|
paths = []
|
||||||
|
for path in splitStrokes(strokes):
|
||||||
|
if len(path)<2:
|
||||||
|
logger.warning(f"Too few strokes to create path: {path}")
|
||||||
|
continue
|
||||||
|
for i in xrange(len(path)):
|
||||||
|
# i += 1 #first item is the move, which we already got above
|
||||||
|
x = float(path[i][0])*factor
|
||||||
|
y = float(path[i][1])*factor
|
||||||
|
abs_x += x
|
||||||
|
abs_y += y
|
||||||
|
|
||||||
|
if i == 0:
|
||||||
|
p = "M%s,%s l" % (abs_x, abs_y)
|
||||||
|
else:
|
||||||
|
p += f"{x:.5},{y:.5} "
|
||||||
|
the_color = "black"
|
||||||
|
stroke_width = 1
|
||||||
|
paths.append(dwg.path(p).stroke(the_color,stroke_width).fill("none"))
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
# little function that displays vector images and saves them to .svg
|
# little function that displays vector images and saves them to .svg
|
||||||
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
|
def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'):
|
||||||
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
|
tf.gfile.MakeDirs(os.path.dirname(svg_filename))
|
||||||
|
@ -221,8 +296,12 @@ def loadCheckpoint(model_dir, nr):
|
||||||
dims = (args.width, args.height)
|
dims = (args.width, args.height)
|
||||||
width = int(re.findall('\d+',args.width)[0])*10
|
width = int(re.findall('\d+',args.width)[0])*10
|
||||||
height = int(re.findall('\d+',args.height)[0])*10
|
height = int(re.findall('\d+',args.height)[0])*10
|
||||||
padding = 20
|
# padding = 20
|
||||||
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}")
|
||||||
|
requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0)
|
||||||
|
dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)]
|
||||||
|
for group in dwgGroups:
|
||||||
|
dwg.add(group)
|
||||||
|
|
||||||
checkpoints = getCheckpoints(args.model_dir)
|
checkpoints = getCheckpoints(args.model_dir)
|
||||||
item_count = args.rows*args.columns
|
item_count = args.rows*args.columns
|
||||||
|
@ -233,36 +312,40 @@ item_count = args.rows*args.columns
|
||||||
target_stroke = test_set.strokes[args.target_sample]
|
target_stroke = test_set.strokes[args.target_sample]
|
||||||
target_z = encode(target_stroke)
|
target_z = encode(target_stroke)
|
||||||
|
|
||||||
max_width = (width-(padding*(args.columns+1))) / args.columns
|
max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns
|
||||||
max_height = (height-(padding*(args.rows+1))) / args.rows
|
max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows
|
||||||
|
|
||||||
|
|
||||||
with tqdm(total=item_count) as pbar:
|
with tqdm(total=item_count) as pbar:
|
||||||
for row in range(args.rows):
|
for row in range(args.rows):
|
||||||
#find the top left point for the strokes
|
#find the top left point for the strokes
|
||||||
min_y = row * (float(height-2*padding) / args.rows) + padding
|
min_y = row * (max_height + args.column_padding) + args.page_margin
|
||||||
for column in range(args.columns):
|
for column in range(args.columns):
|
||||||
min_x = column * (float(width-2*padding) / args.columns) + padding
|
min_x = column * (max_width + args.column_padding) + args.page_margin
|
||||||
item = row*args.columns + column
|
item = row*args.columns + column
|
||||||
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints))
|
||||||
checkpoint = checkpoints[checkpoint_idx]
|
checkpoint = checkpoints[checkpoint_idx]
|
||||||
loadCheckpoint(args.model_dir, checkpoint)
|
loadCheckpoint(args.model_dir, checkpoint)
|
||||||
strokes = decode(target_z, temperature=1)
|
|
||||||
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
|
||||||
dwg.add(path)
|
|
||||||
# draw_strokes(strokes, svg_filename=fn)
|
|
||||||
# for row in range(args.rows):
|
|
||||||
# start_y = row * (float(height) / args.rows)
|
|
||||||
# row_temp = .1+row * (1./args.rows)
|
|
||||||
# for column in range(args.columns):
|
|
||||||
# strokes = decode(temperature=row_temp)
|
|
||||||
# fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg')
|
|
||||||
# draw_strokes(strokes, svg_filename=fn)
|
|
||||||
|
|
||||||
# current_nr = row * args.columns + column
|
isLast = (row == args.rows-1 and column == args.columns-1)
|
||||||
# temp = .01+current_nr * (1./item_count)
|
|
||||||
# start_x = column * (float(width) / args.columns)
|
if isLast and args.last_is_target:
|
||||||
# path = strokesToPath(dwg, strokes, factor, start_x, start_y)
|
strokes = target_stroke
|
||||||
# dwg.add(path)
|
else:
|
||||||
|
strokes = decode(target_z, temperature=1)
|
||||||
|
|
||||||
|
if args.last_in_group and isLast:
|
||||||
|
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||||
|
dwgGroups[-1].add(path)
|
||||||
|
elif args.split_paths:
|
||||||
|
paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||||
|
i = 0
|
||||||
|
for path in paths:
|
||||||
|
group = dwgGroups[i % args.nr_of_paths]
|
||||||
|
i+=1
|
||||||
|
group.add(path)
|
||||||
|
else:
|
||||||
|
path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height)
|
||||||
|
dwgGroups[0].add(path)
|
||||||
|
|
||||||
pbar.update()
|
pbar.update()
|
||||||
dwg.save()
|
dwg.save()
|
||||||
|
|
Loading…
Reference in a new issue