"""Export COCO polygon annotations as one SVG file per image.

For every image that has at least one annotation of ``--category_id``, write
``<output>/<category name>/<image id>.svg``. Each file contains one ``<path>``
per polygon annotation in that image (annotations of *all* categories, not
only the selected one). Each path carries ``super_<supercategory>`` and
``cat_<name>`` CSS classes so the shapes can be styled per category. RLE
masks are skipped because they have no polygon outline.
"""
import argparse
import logging
import os
import urllib.request

import pycocotools.coco
import svgwrite
import tqdm

logger = logging.getLogger("coco")

argParser = argparse.ArgumentParser(description='Create shape SVG\'s')
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_train2017.json',
    help='COCO instances annotation JSON file',
)
argParser.add_argument(
    '--output',
    type=str,
    required=True,  # os.path.join(None, ...) would otherwise crash later
    help='Output directory',
)
argParser.add_argument(
    '--category_id',
    type=int,
    required=True,  # coco.cats[None] would otherwise raise a bare KeyError
    help='Category id',
)

# COCO API handle read by the helpers below; assigned when run as a script.
coco = None


def _css_class(prefix, name):
    """Return ``prefix + name`` as a single CSS class token.

    The SVG ``class`` attribute is whitespace-separated. Without this, a name
    such as ``"traffic light"`` would split into two unrelated classes
    (``cat_traffic`` and ``light``). Runs of whitespace become ``_``.
    """
    return prefix + "_".join(name.split())


def _polygon_path_data(segmentation):
    """Convert a COCO polygon segmentation into SVG path data.

    ``segmentation`` is a list of polygons, each a flat coordinate list
    ``[x1, y1, x2, y2, ...]``. Every polygon becomes its own closed subpath
    (``M x,y L ... Z``), so annotations made of several disjoint parts are
    drawn completely. Polygons with no complete coordinate pair are skipped,
    and a trailing unpaired value is ignored.

    Returns the path data string, or ``""`` if nothing is drawable.
    """
    subpaths = []
    for shape in segmentation:
        points = [
            f"{shape[i]},{shape[i + 1]}" for i in range(0, len(shape) - 1, 2)
        ]
        if not points:
            continue
        head, *tail = points
        subpath = f"M{head}"
        if tail:  # a bare "L" with no coordinates is invalid path data
            subpath += " L " + " ".join(tail)
        subpaths.append(subpath + " Z")
    return " ".join(subpaths)


def createSVGForImage(img_id, filename):
    """Write an SVG sized like image ``img_id`` with its polygon annotations.

    Args:
        img_id: COCO image id (a key of ``coco.imgs``).
        filename: destination path of the SVG file.
    """
    image = coco.imgs[img_id]
    dwg = svgwrite.Drawing(filename, size=(image['width'], image['height']))
    for annotation in coco.imgToAnns[img_id]:
        if isinstance(annotation['segmentation'], dict):
            # RLE masks are stored as a dict (with 'counts') rather than
            # polygon lists; there is no outline to draw, so skip them.
            continue
        addAnnotationToSVG(dwg, annotation)
    dwg.save()


def addAnnotationToSVG(dwg, annotation):
    """Add one polygon annotation to ``dwg`` as a single (multi-part) path.

    Args:
        dwg: ``svgwrite.Drawing`` to append to.
        annotation: COCO annotation dict in polygon form (its
            ``segmentation`` is a list of flat coordinate lists).
    """
    path_data = _polygon_path_data(annotation['segmentation'])
    if not path_data:  # avoid emitting an empty <path d="">
        return
    cat = coco.cats[annotation['category_id']]
    shape_classes = ' '.join([
        _css_class('super_', cat['supercategory']),
        _css_class('cat_', cat['name']),
    ])
    dwg.add(dwg.path(path_data, class_=shape_classes))


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    args = argParser.parse_args()

    logger.info(f"Load {args.annotations}")
    coco = pycocotools.coco.COCO(args.annotations)

    if args.category_id not in coco.cats:
        argParser.error(f"unknown --category_id {args.category_id}")
    category = coco.cats[args.category_id]
    logger.info(f"Using {len(coco.cats)} categories")
    logger.debug(coco.cats)
    logger.info(f"Limit to: {category}")

    dirname = os.path.join(args.output, category['name'])
    # makedirs also creates a missing --output parent; os.mkdir would fail.
    os.makedirs(dirname, exist_ok=True)

    # pycocotools fills catToImgs with one entry per matching annotation, so
    # an image with several instances of the category repeats. Render each
    # image once, keeping the original order.
    img_ids = list(dict.fromkeys(coco.catToImgs[args.category_id]))
    for img_id in tqdm.tqdm(img_ids):
        createSVGForImage(img_id, os.path.join(dirname, f"{img_id}.svg"))

    logger.info("Done")