birthcard/create_dataset.py

151 lines
4.7 KiB
Python
Raw Normal View History

2019-08-25 15:19:27 +00:00
from rdp import rdp
import xml.etree.ElementTree as ET
from svg.path import parse_path
import numpy as np
import logging
import argparse
import os
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('dataset')

# Command-line options. They are parsed at import time into the module-level
# ``args`` object, which main() reads.
argParser = argparse.ArgumentParser(
    description='Create dataset from SVG. We do not mind overfitting, so training==validation==test'
)
argParser.add_argument(
    '--src',
    type=str,
    default='./data/naam',
    help='Directory that is walked recursively for the input SVG files',
)
argParser.add_argument(
    '--dataset_dir',
    type=str,
    default='./datasets/naam',
    help='Directory where one <class>.npz file is written per class',
)
# NOTE(review): main() never reads args.multiply; padding to at least 100
# sketches per class is hard-coded there.
argParser.add_argument(
    '--multiply',
    type=int,
    default=2,
    help="If you don't have enough items, automatically multiply it so there's at least 100 per set.",
)
argParser.add_argument(
    '--verbose',
    '-v',
    action='store_true',
    help='Debug logging',
)
args = argParser.parse_args()
if args.verbose:
    logger.setLevel(logging.DEBUG)
# Ramer-Douglas-Peucker passes applied in order:
# (simplify when the path has more points than this, epsilon).
_RDP_PASSES = ((800, 3.5), (490, 3), (300, 2.0))


def _simplifyPath(path):
    """Shrink a long path with the successive RDP passes in ``_RDP_PASSES``.

    Simplification using Ramer-Douglas-Peucker, see
    https://github.com/tensorflow/magenta/tree/master/magenta/models/sketch_rnn
    """
    for max_points, epsilon in _RDP_PASSES:
        if len(path) > max_points:
            logger.debug(f'simplify {len(path)} factor {epsilon:g}')
            path = rdp(path, epsilon=epsilon)
            logger.debug(f'\tnow {len(path)}')
    return path


def getStroke3FromSvg(filename):
    """
    Convert the drawings in an SVG file to stroke-3 sketches.

    Every direct-child ``<g>`` of the SVG root is a separate drawing, and every
    direct-child ``<path>`` of that group is one pen-down stroke. Only the end
    point of each path segment is used. A drawing is translated so its
    bounding box starts at (0, 0), and it is scaled uniformly so its height is
    512 units. Long paths are simplified with Ramer-Douglas-Peucker.

    Args:
        filename: path of the SVG file to read.

    Returns:
        A list with one sketch per usable group. A sketch is a list of
        ``[dx, dy, pen_state]`` rows of Python ints. ``dx``/``dy`` are integer
        offsets from the previous point. ``pen_state`` is 1 on the last point
        of each path (the pen is lifted afterwards) and 0 otherwise. The first
        point of a drawing is the origin and produces no row. Groups without
        drawable paths, with zero height, or yielding no rows are skipped with
        a warning.
    """
    svg_ns = "{http://www.w3.org/2000/svg}"
    logger.debug(f"Read {filename}")
    root = ET.parse(filename).getroot()
    sketches = []
    for group in root.findall(svg_ns + "g"):
        # One (n, 2) float array of segment end points per non-empty <path>.
        paths = []
        for element in group.findall(svg_ns + "path"):
            d = element.get("d")
            if not d:
                continue
            ends = [[segment.end.real, segment.end.imag] for segment in parse_path(d)]
            if ends:
                paths.append(np.array(ends, dtype=float))
        if not paths:
            logger.warning(f"Skip group without drawable paths in {filename}")
            continue

        # scale normalize & crop: bounding box to the origin, height to 512.
        all_points = np.concatenate(paths)
        min_x, min_y = all_points.min(axis=0)
        height = all_points[:, 1].max() - min_y
        if height <= 0:
            logger.warning(f"Skip group with zero height in {filename}")
            continue
        scale = 512 / height

        strokes = []
        prev = None
        for path in paths:
            path = _simplifyPath((path - (min_x, min_y)) * scale)
            # Round the absolute positions first and difference those. Truncating
            # every float delta on its own loses up to one unit per point, so the
            # error accumulates along a path and sub-unit steps collapse to 0.
            for x, y in np.rint(path).astype(int):
                if prev is not None:
                    strokes.append([int(x - prev[0]), int(y - prev[1]), 0])
                prev = (x, y)
            # mark lifting of pen after the last point of this path
            if strokes:
                strokes[-1][2] = 1
        if not strokes:
            logger.warning(f"Skip group with fewer than two points in {filename}")
            continue
        logger.debug(f"Paths: {len(strokes)}")
        sketches.append(strokes)
    return sketches
def main():
    """
    Build one ``.npz`` dataset per class from the SVG files under ``args.src``.

    ``args.src`` is walked recursively. Each ``.svg`` file is converted with
    ``getStroke3FromSvg``, and its sketches are pooled under a class name. The
    class name is the file name without its extension and without trailing
    digits, dots, dashes and spaces (``"name-01.svg"`` -> ``"name"``).

    Classes with fewer than 100 sketches are padded by cycling through their
    own sketches. Overfitting is intended, so the same sketches are stored as
    ``train``, ``valid`` and ``test`` in ``<args.dataset_dir>/<class>.npz``.
    """
    # NOTE(review): the --multiply option is parsed but not used; the padding
    # target is fixed here.
    min_samples = 100
    sketches = {}
    for dirName, subdirList, fileList in os.walk(args.src):
        # Sort in place so the walk, and therefore sketch order and padding,
        # is deterministic across runs.
        subdirList.sort()
        for fname in sorted(fileList):
            # Only SVGs can be parsed. This also guarantees fname[:-4] strips ".svg".
            if not fname.lower().endswith('.svg'):
                logger.debug(f"Skip non-SVG file {fname}")
                continue
            filename = os.path.join(dirName, fname)
            className = fname[:-4].rstrip('0123456789.- ')
            sketches.setdefault(className, []).extend(getStroke3FromSvg(filename))
    if not sketches:
        logger.warning(f"No SVG files found in {args.src}")
        return

    os.makedirs(args.dataset_dir, exist_ok=True)
    for className, samples in sketches.items():
        filename = os.path.join(args.dataset_dir, className + '.npz')
        itemsForClass = len(samples)
        if itemsForClass == 0:
            # Nothing to cycle through when padding (would divide by zero).
            logger.warning(f"No usable sketches for class {className}, skipping")
            continue
        if itemsForClass < min_samples:
            logger.info(f"Loop to have at least 100 for class {className} (now {itemsForClass})")
            samples = samples + [
                samples[i % itemsForClass] for i in range(min_samples - itemsForClass)
            ]
            logger.debug(f"Now {len(samples)} samples for {className}")
        # Sketches differ in length, so store a 1-D object array holding one
        # (n, 3) int16 array per sketch (assumes |offset| <= 32767). Building it
        # explicitly avoids NumPy's implicit ragged-array creation, which raises
        # ValueError since NumPy 1.24. Elements are assigned one by one so
        # equal-length sketches are not merged into a 3-D array. Object arrays
        # are pickled, so np.load needs allow_pickle=True.
        sets = np.empty(len(samples), dtype=object)
        for i, sample in enumerate(samples):
            sets[i] = np.array(sample, dtype=np.int16).reshape(-1, 3)
        np.savez_compressed(filename, train=sets, valid=sets, test=sets)
        logger.info(f"Saved {len(sets)} samples in {filename}")
# Script entry point: build the datasets only when run directly, not on import.
if __name__ == "__main__":
    main()