from rdp import rdp
import xml.etree.ElementTree as ET
from svg.path import parse_path
import numpy as np
import logging
import argparse
import os

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')

argParser = argparse.ArgumentParser(description='Create dataset from SVG. We do not mind overfitting, so training==validation==test')
argParser.add_argument(
    '--src',
    type=str,
    default='./data/naam',
    help=''
)
argParser.add_argument(
    '--dataset_dir',
    type=str,
    default='./datasets/naam',
)
# NOTE(review): --multiply is parsed but never used anywhere below; main()
# hard-codes a minimum of 100 samples per class instead. Kept for CLI
# backward compatibility — confirm whether it should drive that minimum.
argParser.add_argument(
    '--multiply',
    type=int,
    default=2,
    help="If you dont have enough items, automatically multiply it so there's at least 100 per set.",
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging'
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)


def getStroke3FromSvg(filename):
    """Read an SVG file and convert each <g> group into a stroke-3 drawing.

    Each drawing is a list of ``[dx, dy, pen_state]`` triplets of integer
    deltas; ``pen_state == 1`` marks the last point of a path (pen lifted).
    Coordinates are translated so the group's bounding box starts at the
    origin and scaled so its total height is 512 units. Long paths are
    simplified with Ramer-Douglas-Peucker as recommended for sketch-rnn, see
    https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn

    :param filename: path of the SVG file to parse
    :return: list of drawings, one per SVG ``<g>`` group
    """
    # BUGFIX(review): the log messages here and in main() contained the
    # literal "(unknown)" where the file name belongs; restored {filename}.
    logger.debug(f"Read {filename}")
    s = ET.parse(filename)
    root = s.getroot()
    groups = root.findall("{http://www.w3.org/2000/svg}g")
    sketches = []
    for group in groups:
        svg_paths = group.findall("{http://www.w3.org/2000/svg}path")
        paths = []
        # Bounding box accumulated over every path in this group.
        min_x = None
        min_y = None
        max_y = None
        for p in svg_paths:
            path = parse_path(p.get("d"))
            # Only segment end points are kept: curves become polylines.
            points = [[point.end.real, point.end.imag] for point in path]
            if not points:
                # Empty path data would crash min()/max() below; skip it.
                continue
            x_points = np.array([pt[0] for pt in points])
            y_points = np.array([pt[1] for pt in points])
            if min_x is None:
                min_x = min(x_points)
                min_y = min(y_points)
                max_y = max(y_points)
            else:
                min_x = min(min_x, min(x_points))
                min_y = min(min_y, min(y_points))
                max_y = max(max_y, max(y_points))
            paths.append(np.stack([x_points, y_points], axis=1))
        if min_x is None:
            # Group without any drawable point: skip it instead of crashing
            # on the scale computation below.
            continue
        # Scale-normalize & crop; guard against a perfectly flat drawing
        # (max_y == min_y) which would divide by zero.
        height = max_y - min_y
        scale = 512 / height if height else 1.0
        prev_x = None
        prev_y = None
        strokes = []
        for path in paths:
            path[:, 0] -= min_x
            path[:, 1] -= min_y
            path *= scale
            # Simplify using Ramer-Douglas-Peucker, progressively more
            # aggressive until the path is short enough; see
            # https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
            if len(path) > 800:
                logger.debug(f'simplify {len(path)} factor 3.5')
                path = rdp(path, epsilon=3.5)
                logger.debug(f'\tnow {len(path)}')
            if len(path) > 490:
                logger.debug(f'simplify {len(path)} factor 3')
                path = rdp(path, epsilon=3)
                logger.debug(f'\tnow {len(path)}')
            if len(path) > 300:
                logger.debug(f'simplify {len(path)} factor 2')
                path = rdp(path, epsilon=2.0)
                logger.debug(f'\tnow {len(path)}')
            for point in path:
                # The very first point of the drawing only seeds
                # prev_x/prev_y and produces no delta triplet.
                if prev_x is not None and prev_y is not None:
                    strokes.append([int(point[0] - prev_x), int(point[1] - prev_y), 0])
                prev_x = point[0]
                prev_y = point[1]
            # Mark lifting of pen. Guard: when the first path's only point
            # merely seeded prev_x/prev_y, strokes is still empty and
            # strokes[-1] would raise IndexError.
            if strokes:
                strokes[-1][2] = 1
        logger.debug(f"Paths: {len(strokes)}")
        # strokes = np.array(strokes, dtype=np.int16)
        sketches.append(strokes)
    return sketches


def main():
    """Walk args.src for SVG files, group drawings by class, save .npz sets.

    The class name is the file name minus its 4-character extension and any
    trailing counter characters. Train/valid/test are deliberately identical
    (see the argparse description: overfitting is acceptable here).
    """
    sketches = {}
    for dirName, subdirList, fileList in os.walk(args.src):
        for fname in fileList:
            filename = os.path.join(dirName, fname)
            # NOTE(review): assumes every file has a 3-letter extension
            # (.svg) — non-SVG files in the tree would break ET.parse.
            className = fname[:-4].rstrip('0123456789.- ')
            if className not in sketches:
                sketches[className] = []
            sketches[className].extend(getStroke3FromSvg(filename))

    # np.savez_compressed cannot create missing directories itself.
    os.makedirs(args.dataset_dir, exist_ok=True)
    for className in sketches:
        filename = os.path.join(args.dataset_dir, className + '.npz')
        itemsForClass = len(sketches[className])
        if itemsForClass < 100:
            # Duplicate existing samples round-robin until the class has at
            # least 100 entries.
            logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
            extras = []
            for i in range(100 - itemsForClass):
                extras.append(sketches[className][i % itemsForClass])
            sketches[className].extend(extras)
            logger.debug(f"Now {len(sketches[className])} samples for {className}")

        sets = sketches[className]
        np.savez_compressed(filename, train=sets, valid=sets, test=sets)
        logger.info(f"Saved {len(sets)} samples in {filename}")


if __name__ == '__main__':
    main()