From fa87808bb5f52cfb19d7a75fd6fb7a9198c55911 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Mon, 26 Aug 2019 11:55:26 +0200 Subject: [PATCH] Split paths over various groups and allow last item to be purely a sample --- create_card.py | 137 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 110 insertions(+), 27 deletions(-) diff --git a/create_card.py b/create_card.py index 7e02c57..29e14f6 100644 --- a/create_card.py +++ b/create_card.py @@ -71,17 +71,43 @@ argParser.add_argument( type=int, default=5, ) +argParser.add_argument( + '--column_padding', + type=int, + default=10, + ) +argParser.add_argument( + '--page_margin', + type=int, + default=70, + ) argParser.add_argument( '--max_checkpoint_factor', type=float, default=1., help='If there are too many checkpoints that create smooth outcomes, limit that with this factor' ) -# argParser.add_argument( -# '--output_file', -# type=str, -# default='card.svg', -# ) +argParser.add_argument( + '--split_paths', + action='store_true', + help='If a stroke contains multiple steps/paths, split these into separate paths/strokes' + ) +argParser.add_argument( + '--nr_of_paths', + type=int, + default=3, + help='If split_paths is given, this number defines the number of groups the paths should be split over' + ) +argParser.add_argument( + '--last_is_target', + action='store_true', + help='Last item not generated, but as the given target' + ) +argParser.add_argument( + '--last_in_group', + action='store_true', + help='If set, put the last rendition into a separate group' + ) argParser.add_argument( '--verbose', '-v', @@ -121,6 +147,55 @@ def strokesToPath(dwg, strokes, start_x, start_y, max_width, max_height): stroke_width = 1 return dwg.path(p).stroke(the_color,stroke_width).fill("none") +def splitStrokes(strokes): + """ + turn [[x,y,0],[x,y,1],[x,y,0]] into [[[x,y,0],[x,y,1]], [[x,y,0]]] + """ + subStrokes = [] + for stroke in strokes: + subStrokes.append(stroke) + if stroke[2] == 1 and subStrokes: + yield subStrokes + subStrokes = [] + if subStrokes: + yield subStrokes + +def strokesToSplitPaths(dwg, strokes, start_x, start_y, max_width, max_height): + """ + Stroke not to a single path but each lift pen to a next path + """ + + lift_pen = 1 + min_x, max_x, min_y, max_y = get_bounds(strokes, 1) + factor_x = max_width / (max_x-min_x) + factor_y = max_height / (max_y-min_y) + # assuming both < 1 + factor = factor_y if factor_y < factor_x else factor_x + + abs_x = start_x - min_x * factor + abs_y = start_y - min_y * factor + paths = [] + for path in splitStrokes(strokes): + if len(path)<2: + logger.warning(f"Too few strokes to create path: {path}") + continue + for i in xrange(len(path)): + # i += 1 #first item is the move, which we already got above + x = float(path[i][0])*factor + y = float(path[i][1])*factor + abs_x += x + abs_y += y + + if i == 0: + p = "M%s,%s l" % (abs_x, abs_y) + else: + p += f"{x:.5},{y:.5} " + the_color = "black" + stroke_width = 1 + paths.append(dwg.path(p).stroke(the_color,stroke_width).fill("none")) + return paths + + # little function that displays vector images and saves them to .svg def draw_strokes(strokes, factor=0.2, svg_filename = '/tmp/sketch_rnn/svg/sample.svg'): tf.gfile.MakeDirs(os.path.dirname(svg_filename)) @@ -221,8 +296,12 @@ def loadCheckpoint(model_dir, nr): dims = (args.width, args.height) width = int(re.findall('\d+',args.width)[0])*10 height = int(re.findall('\d+',args.height)[0])*10 -padding = 20 +# padding = 20 dwg = svgwrite.Drawing(args.output_file, size=dims, viewBox=f"0 0 {width} {height}") +requiredGroups = (args.nr_of_paths if args.split_paths else 1) + (1 if args.last_in_group else 0) +dwgGroups = [svgwrite.container.Group(id=f"g{i}") for i in range(requiredGroups)] +for group in dwgGroups: + dwg.add(group) checkpoints = getCheckpoints(args.model_dir) item_count = args.rows*args.columns @@ -233,36 +312,40 @@ item_count = args.rows*args.columns target_stroke = test_set.strokes[args.target_sample] target_z = encode(target_stroke) -max_width = (width-(padding*(args.columns+1))) / args.columns -max_height = (height-(padding*(args.rows+1))) / args.rows - +max_width = (width - args.page_margin*2 - (args.column_padding*(args.columns-1))) / args.columns +max_height = (height - args.page_margin*2 -(args.column_padding*(args.rows-1))) / args.rows with tqdm(total=item_count) as pbar: for row in range(args.rows): #find the top left point for the strokes - min_y = row * (float(height-2*padding) / args.rows) + padding + min_y = row * (max_height + args.column_padding) + args.page_margin for column in range(args.columns): - min_x = column * (float(width-2*padding) / args.columns) + padding + min_x = column * (max_width + args.column_padding) + args.page_margin item = row*args.columns + column checkpoint_idx = math.floor(float(item)*args.max_checkpoint_factor/item_count * len(checkpoints)) checkpoint = checkpoints[checkpoint_idx] loadCheckpoint(args.model_dir, checkpoint) - strokes = decode(target_z, temperature=1) - path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) - dwg.add(path) - # draw_strokes(strokes, svg_filename=fn) - # for row in range(args.rows): - # start_y = row * (float(height) / args.rows) - # row_temp = .1+row * (1./args.rows) - # for column in range(args.columns): - # strokes = decode(temperature=row_temp) - # fn = os.path.join(args.output_dir, f'generated-{row_temp:.2d}-{column}.svg') - # draw_strokes(strokes, svg_filename=fn) - # current_nr = row * args.columns + column - # temp = .01+current_nr * (1./item_count) - # start_x = column * (float(width) / args.columns) - # path = strokesToPath(dwg, strokes, factor, start_x, start_y) - # dwg.add(path) + isLast = (row == args.rows-1 and column == args.columns-1) + + if isLast and args.last_is_target: + strokes = target_stroke + else: + strokes = decode(target_z, temperature=1) + + if args.last_in_group and isLast: + path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) + dwgGroups[-1].add(path) + elif args.split_paths: + paths = strokesToSplitPaths(dwg, strokes, min_x, min_y, max_width, max_height) + i = 0 + for path in paths: + group = dwgGroups[i % args.nr_of_paths] + i+=1 + group.add(path) + else: + path = strokesToPath(dwg, strokes, min_x, min_y, max_width, max_height) + dwgGroups[0].add(path) + pbar.update() dwg.save()