# coco/create_shapes.py
#
# (81 lines, 2.1 KiB, Python)

import pycocotools.coco
import argparse
import logging
import tqdm
import urllib.request
import os
import svgwrite
# Module-level setup: configure logging, parse the command line and load the
# COCO annotation file once. The functions below read the module-level
# `coco` and `args` objects created here.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco")

argParser = argparse.ArgumentParser(description='Create shape SVG\'s')
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_train2017.json',
    help='Path to a COCO instances annotation JSON file',
)
# --output and --category_id have no usable default: the script later joins
# the output path and indexes coco.cats with the category id, so a missing
# value would otherwise surface as an opaque TypeError/KeyError far from here.
argParser.add_argument(
    '--output',
    type=str,
    required=True,
    help='Output directory',
)
argParser.add_argument(
    '--category_id',
    type=int,
    required=True,
    help='Category id',
)
args = argParser.parse_args()

logger.info(f"Load {args.annotations}")
coco = pycocotools.coco.COCO(args.annotations)
def createSVGForImage(img_id, filename):
    """Render every polygon annotation of one COCO image into an SVG file.

    The drawing is sized to the image's width and height taken from
    ``coco.imgs`` so polygon coordinates line up with the original pixels.
    Annotations whose segmentation contains a ``'counts'`` key (RLE masks)
    are skipped; every other annotation is handed to ``addAnnotationToSVG``.

    Args:
        img_id: COCO image id; must be a key of ``coco.imgs``.
        filename: destination path of the SVG written by ``save()``.
    """
    image_info = coco.imgs[img_id]
    drawing = svgwrite.Drawing(
        filename, size=(image_info['width'], image_info['height'])
    )
    # RLE masks are not supported by the polygon-to-path conversion.
    polygon_annotations = (
        ann for ann in coco.imgToAnns[img_id]
        if 'counts' not in ann['segmentation']
    )
    for ann in polygon_annotations:
        addAnnotationToSVG(drawing, ann)
    drawing.save()
def addAnnotationToSVG(dwg, annotation):
    """Append one COCO polygon annotation to ``dwg`` as a single SVG <path>.

    ``annotation['segmentation']`` is a list of polygons, each a flat list
    ``[x1, y1, x2, y2, ...]``. Every polygon becomes its own closed subpath
    (``M x,y L ... Z``) inside one path element, so objects whose mask is
    split into several parts are drawn completely. The path gets the CSS
    classes ``super_<supercategory>`` and ``cat_<name>`` from ``coco.cats``.

    Args:
        dwg: the svgwrite ``Drawing`` the path is added to.
        annotation: a COCO annotation dict with a polygon ``segmentation``
            and a ``category_id``. RLE segmentations must be filtered out by
            the caller.
    """
    subpaths = []
    for shape in annotation['segmentation']:
        # Pair up the flat coordinate list; a trailing unpaired value is
        # ignored, matching the previous index arithmetic.
        points = [f"{x},{y}" for x, y in zip(shape[0::2], shape[1::2])]
        if not points:
            # Degenerate polygon without a single complete point.
            continue
        segment = f"M{points[0]}"
        if len(points) > 1:
            # An "L" command with no coordinates is invalid path data, so it
            # is only emitted when there is at least one more point.
            segment += " L " + " ".join(points[1:])
        # COCO polygons describe closed outlines; close each subpath so the
        # stroke is closed too, not just the implicit fill.
        subpaths.append(segment + " Z")
    # Previously the path string was reassigned per polygon, so only the last
    # part of a multi-polygon segmentation was kept. All parts are joined now.
    if not subpaths:
        return
    cat = coco.cats[annotation['category_id']]
    shape_classes = ' '.join([
        f"super_{cat['supercategory']}",
        f"cat_{cat['name']}",
    ])
    dwg.add(dwg.path(' '.join(subpaths), class_=shape_classes))
# Main: write one SVG per image containing the requested category to
# <output>/<category name>/<image id>.svg.
if args.category_id not in coco.cats:
    # Report a bad id as a CLI usage error instead of a KeyError traceback.
    argParser.error(f"unknown --category_id {args.category_id}")
category = coco.cats[args.category_id]
logger.info(f"Using {len(coco.cats)} categories")
logger.debug(coco.cats)
logger.info(f"Limit to: {category}")

dirname = os.path.join(args.output, category['name'])
# makedirs also creates a missing --output directory and tolerates an
# existing one without a separate (racy) existence check.
os.makedirs(dirname, exist_ok=True)

# pycocotools records one catToImgs entry per *annotation*, so an image with
# several objects of this category appears several times. Deduplicate while
# keeping first-seen order so each SVG is rendered only once.
img_ids = list(dict.fromkeys(coco.catToImgs[args.category_id]))
for img_id in tqdm.tqdm(img_ids):
    fn = os.path.join(dirname, f"{img_id}.svg")
    createSVGForImage(img_id, fn)
logger.info("Done")