81 lines
No EOL
2.1 KiB
Python
81 lines
No EOL
2.1 KiB
Python
import pycocotools.coco
|
|
import argparse
|
|
import logging
|
|
import tqdm
|
|
import urllib.request
|
|
import os
|
|
import svgwrite
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco")

# Command-line interface.
# --output and --category_id have no usable default: the script later joins a
# path onto args.output and indexes coco.cats with args.category_id. They are
# therefore required, so a missing option produces a usage error right away
# instead of a TypeError/KeyError after the annotation file has been loaded.
argParser = argparse.ArgumentParser(description='Create shape SVG\'s')
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_train2017.json',
    help='Path to the COCO annotations JSON file',
)
argParser.add_argument(
    '--output',
    type=str,
    required=True,
    help='Output directory',
)
argParser.add_argument(
    '--category_id',
    type=int,
    required=True,
    help='Category id',
)
args = argParser.parse_args()

# Load the annotation file through pycocotools; `coco` is used by the rest of
# the script for image, annotation and category lookups.
logger.info(f"Load {args.annotations}")
coco = pycocotools.coco.COCO(args.annotations)
|
|
|
|
def createSVGForImage(img_id, filename):
    """Write an SVG file containing the polygon annotations of one image.

    The drawing is sized to the image's width and height as recorded in
    ``coco.imgs``. Every annotation listed for the image in
    ``coco.imgToAnns`` is passed to ``addAnnotationToSVG``, except those
    whose segmentation contains a ``'counts'`` key (RLE masks), which are
    skipped.

    Args:
        img_id: Key of the image in ``coco.imgs`` / ``coco.imgToAnns``.
        filename: Path of the SVG file to save.
    """
    image_info = coco.imgs[img_id]
    canvas_size = (image_info['width'], image_info['height'])
    dwg = svgwrite.Drawing(filename, size=canvas_size)

    # RLE masks carry a 'counts' entry and cannot be drawn as polygons here.
    polygon_annotations = (
        candidate
        for candidate in coco.imgToAnns[img_id]
        if 'counts' not in candidate['segmentation']
    )
    for polygon_annotation in polygon_annotations:
        addAnnotationToSVG(dwg, polygon_annotation)

    dwg.save()
|
|
|
|
def addAnnotationToSVG(dwg, annotation):
    """Add one annotation's polygon segmentation to *dwg* as a single path.

    Each entry of ``annotation['segmentation']`` is treated as a flat
    coordinate list ``[x0, y0, x1, y1, ...]`` and becomes one ``M ... L ...``
    subpath. All subpaths of the annotation are combined into one path
    element, so an object made of several polygons keeps all of them
    (previously each polygon overwrote the one before it and only the last
    survived).

    The path gets two CSS classes derived from the annotation's category in
    ``coco.cats``: ``super_<supercategory>`` and ``cat_<name>``.

    Args:
        dwg: The svgwrite drawing the path is added to.
        annotation: Annotation dict with ``'segmentation'`` and
            ``'category_id'`` entries.
    """
    segmentation = annotation['segmentation']
    if isinstance(segmentation, dict):
        # Dict-style segmentations (RLE masks with a 'counts' entry) are not
        # polygons; createSVGForImage skips them too.
        return

    subpaths = []
    for shape in segmentation:
        # Pair up the flat coordinate list; a trailing unpaired value is
        # dropped, matching the integer truncation of the original loop.
        points = list(zip(shape[0::2], shape[1::2]))
        if not points:
            # Nothing to draw for an empty shape.
            continue
        (start_x, start_y), remaining = points[0], points[1:]
        subpath = f"M{start_x},{start_y} "
        if remaining:
            # Only emit the lineto command when it has coordinates to follow.
            subpath += "L " + "".join(f"{x},{y} " for x, y in remaining)
        subpaths.append(subpath)
    p = "".join(subpaths)

    cat = coco.cats[annotation['category_id']]
    shape_classes = ' '.join([
        f"super_{cat['supercategory']}",
        f"cat_{cat['name']}",
    ])
    dwg.add(dwg.path(p, class_=shape_classes))
|
|
|
|
# Driver: export one SVG per image that has annotations of the requested
# category, into <output>/<category name>/<image id>.svg.
category = coco.cats[args.category_id]
logger.info(f"Using {len(coco.cats)} categories")
logger.debug(coco.cats)

logger.info(f"Limit to: {category}")

dirname = os.path.join(args.output, category['name'])

# makedirs also creates a missing --output directory and does not fail when
# the target already exists (no check-then-create race).
os.makedirs(dirname, exist_ok=True)

# pycocotools appends an image id to catToImgs once per matching annotation,
# so ids repeat for images with several objects of the category. De-duplicate
# (keeping first-seen order) so each SVG is generated only once.
image_ids = list(dict.fromkeys(coco.catToImgs[args.category_id]))

for img_id in tqdm.tqdm(image_ids):
    fn = os.path.join(dirname, f"{img_id}.svg")
    createSVGForImage(img_id, fn)

logger.info("Done")