"""Create stroke-3 datasets (.npz files) from SVG drawings."""
from rdp import rdp
|
|
import xml.etree.ElementTree as ET
|
|
from svg.path import parse_path
|
|
import numpy as np
|
|
import logging
|
|
import argparse
|
|
import os
|
|
|
|
|
|
# Log at INFO by default; --verbose (below) lowers the threshold of this
# script's own logger to DEBUG.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')

# Command line interface. The arguments are parsed at import time and the
# resulting `args` namespace is read as a module-level global by main().
argParser = argparse.ArgumentParser(
    description='Create dataset from SVG. We do not mind overfitting, so training==validation==test'
)
argParser.add_argument(
    '--src',
    type=str,
    default='./data/naam',
    help='Directory that is searched recursively for the SVG files to convert',
)
argParser.add_argument(
    '--dataset_dir',
    type=str,
    default='./datasets/naam',
    help='Directory in which one .npz file per class is written',
)
# NOTE(review): the parsed value of --multiply is not read anywhere in the
# code of this script; main() always pads a class up to 100 samples.
argParser.add_argument(
    '--multiply',
    type=int,
    default=2,
    help="If you dont have enough items, automatically multiply it so there's at least 100 per set.",
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging',
)
args = argParser.parse_args()

if args.verbose:
    logger.setLevel(logging.DEBUG)
|
|
|
|
|
|
def _simplify_path(path):
    """Thin out a long path with the Ramer-Douglas-Peucker algorithm.

    The thresholds are checked one after another on the (possibly already
    simplified) path, so a very long path can be simplified more than once.
    Paths of 300 points or fewer are returned untouched.
    """
    # see https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
    for limit, epsilon, label in ((800, 3.5, '3.5'), (490, 3, '3'), (300, 2.0, '2')):
        if len(path) > limit:
            logger.debug(f'simplify {len(path)} factor {label}')
            path = rdp(path, epsilon=epsilon)
            logger.debug(f'\tnow {len(path)}')
    return path


def getStroke3FromSvg(filename):
    """
    Get the drawings of an SVG file in stroke-3 format.

    Every <g> element directly below the SVG root is treated as a separate
    drawing, built from the <path> elements directly below that group. Each
    drawing is shifted so its minimum x/y is at the origin and scaled so its
    height becomes 512.

    :param filename: path of the SVG file to read
    :return: list of drawings; each drawing is a list of
             [dx, dy, pen state] entries where dx/dy are integer offsets to
             the previous point and pen state is 1 on the last point of a
             path (pen lifted afterwards) and 0 otherwise

    Groups that cannot be converted (no path points, zero height, or no
    resulting strokes) are skipped with a warning instead of raising.
    """
    logger.debug(f"Read {filename}")
    svg_ns = "{http://www.w3.org/2000/svg}"
    root = ET.parse(filename).getroot()

    sketches = []
    for group in root.findall(svg_ns + "g"):
        # Collect the end point of every segment of every path in the group.
        paths = []
        for svg_path in group.findall(svg_ns + "path"):
            d = svg_path.get("d")
            if not d:
                # <path> without path data: nothing to draw
                continue
            points = np.array(
                [[segment.end.real, segment.end.imag] for segment in parse_path(d)],
                dtype=float,
            )
            if len(points):
                paths.append(points)

        if not paths:
            logger.warning(f"Skipping group without path points in {filename}")
            continue

        # scale normalize & crop
        all_points = np.concatenate(paths)
        min_x, min_y = all_points.min(axis=0)
        height = all_points[:, 1].max() - min_y
        if height == 0:
            # A flat drawing would need an infinite scale factor.
            logger.warning(f"Skipping group with zero height in {filename}")
            continue
        scale = 512 / height

        prev_x = None
        prev_y = None

        strokes = []
        for path in paths:
            path = _simplify_path((path - (min_x, min_y)) * scale)

            for x, y in path:
                # The very first point only sets the start position.
                if prev_x is not None and prev_y is not None:
                    strokes.append([int(x - prev_x), int(y - prev_y), 0])
                prev_x, prev_y = x, y

            if strokes:
                # mark lifting of pen
                strokes[-1][2] = 1
            else:
                # A leading single-point path produced no stroke, so there is
                # no entry that could carry the pen lift. Forget the point so
                # that no line is drawn from it to the next path.
                prev_x = prev_y = None

        if not strokes:
            logger.warning(f"Skipping group without strokes in {filename}")
            continue

        logger.debug(f"Paths: {len(strokes)}")
        sketches.append(strokes)

    return sketches
|
|
|
|
def main():
    """
    Convert all SVG files below args.src into one .npz file per class.

    The class name of a file is its name without extension and without
    trailing digits, dots, dashes and spaces, so "name 1.svg" and
    "name-2.svg" end up in the same class. Classes with fewer than 100
    drawings are padded to 100 by repeating their drawings in order. The
    same samples are stored under the keys train, valid and test.
    """
    sketches = {}
    for dirName, _subdirList, fileList in os.walk(args.src):
        for fname in fileList:
            stem, extension = os.path.splitext(fname)
            if extension.lower() != '.svg':
                # Other files (e.g. .DS_Store) would fail in the XML parser.
                logger.debug(f"Skip non-SVG file {fname}")
                continue
            className = stem.rstrip('0123456789.- ')
            filename = os.path.join(dirName, fname)
            sketches.setdefault(className, []).extend(getStroke3FromSvg(filename))

    # np.savez_compressed does not create missing directories.
    os.makedirs(args.dataset_dir, exist_ok=True)

    for className, samples in sketches.items():
        filename = os.path.join(args.dataset_dir, className + '.npz')
        itemsForClass = len(samples)
        if itemsForClass == 0:
            # Nothing to repeat; padding would divide by zero.
            logger.warning(f"No drawings found for class {className}, nothing saved")
            continue
        if itemsForClass < 100:
            logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
            extras = [samples[i % itemsForClass] for i in range(100 - itemsForClass)]
            samples.extend(extras)

        logger.debug(f"Now {len(samples)} samples for {className}")

        # Drawings differ in length. Converting such a ragged nested list
        # implicitly raises ValueError on NumPy >= 1.24, so build the 1-D
        # object array explicitly with one drawing per element.
        # Loading the file therefore needs np.load(..., allow_pickle=True).
        sets = np.empty(len(samples), dtype=object)
        for index, sample in enumerate(samples):
            sets[index] = sample

        np.savez_compressed(filename, train=sets, valid=sets, test=sets)
        logger.info(f"Saved {len(sets)} samples in {filename}")


# Run the conversion only when executed as a script.
if __name__ == '__main__':
    main()
|