81 lines
2.1 KiB
Python
81 lines
2.1 KiB
Python
|
import pycocotools.coco
|
||
|
import argparse
|
||
|
import logging
|
||
|
import tqdm
|
||
|
import urllib.request
|
||
|
import os
|
||
|
import svgwrite
|
||
|
|
||
|
|
||
|
# Show INFO-level progress messages for the whole run.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco")

# Command-line interface.
# --output and --category_id are used unconditionally further down the script
# (os.path.join / coco.cats lookup), so they are required here: a missing value
# is reported by argparse immediately instead of surfacing as a TypeError or
# KeyError after the annotation file has already been loaded.
argParser = argparse.ArgumentParser(description='Create shape SVG\'s')
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_train2017.json',
    help='Path to the COCO annotations JSON file',
)
argParser.add_argument(
    '--output',
    type=str,
    required=True,
    help='Output directory',
)
argParser.add_argument(
    '--category_id',
    type=int,
    required=True,
    help='Category id',
)
args = argParser.parse_args()

# Load the annotation file; `coco` is read as a module-level global by the
# functions defined below.
logger.info(f"Load {args.annotations}")
coco = pycocotools.coco.COCO(args.annotations)
|
||
|
|
||
|
def createSVGForImage(img_id, filename):
    """Write an SVG file containing the polygon annotations of one image.

    The drawing is sized to the image's width and height as recorded in the
    module-level ``coco`` index, and is saved to ``filename``.

    Args:
        img_id: key into ``coco.imgs`` / ``coco.imgToAnns``.
        filename: path of the SVG file to write.
    """
    canvas_size = (coco.imgs[img_id]['width'], coco.imgs[img_id]['height'])
    drawing = svgwrite.Drawing(filename, size=canvas_size)

    # Segmentations that contain a 'counts' entry are RLE masks rather than
    # polygon lists; those are left out of the drawing.
    polygon_annotations = (
        ann for ann in coco.imgToAnns[img_id]
        if 'counts' not in ann['segmentation']
    )
    for ann in polygon_annotations:
        addAnnotationToSVG(drawing, ann)

    drawing.save()
|
||
|
|
||
|
def addAnnotationToSVG(dwg, annotation):
    """Add one polygon annotation to the drawing as a single SVG path.

    Every polygon in the annotation's segmentation becomes its own sub-path
    (``M x,y L x,y ...``) inside one ``<path>`` element, so annotations made
    of several polygons are drawn completely rather than only their last
    polygon.

    Args:
        dwg: drawing object providing ``path()`` and ``add()``.
        annotation: annotation dict; reads ``'segmentation'`` (an iterable of
            flat ``[x0, y0, x1, y1, ...]`` coordinate lists) and
            ``'category_id'`` (a key into the module-level ``coco.cats``).
    """
    subpaths = []
    for shape in annotation['segmentation']:
        # Pair up the flat coordinate list; a dangling odd value is ignored.
        points = [f"{x},{y}" for x, y in zip(shape[0::2], shape[1::2])]
        if not points:
            continue
        subpath = f"M{points[0]}"
        if len(points) > 1:
            # Only emit the lineto command when there is something to draw
            # to; a bare trailing "L" is not valid path data.
            subpath += " L " + " ".join(points[1:])
        subpaths.append(subpath)

    if not subpaths:
        # Nothing drawable: do not add an empty <path> element.
        return

    cat = coco.cats[annotation['category_id']]
    # NOTE(review): if a category name contains a space, the class attribute
    # is split into several class tokens -- confirm whether consumers of the
    # SVG rely on the current naming before changing it.
    shape_classes = ' '.join([
        f"super_{cat['supercategory']}",
        f"cat_{cat['name']}",
    ])
    dwg.add(dwg.path(" ".join(subpaths), class_=shape_classes))
|
||
|
|
||
|
# Resolve the requested category, failing with a readable message rather
# than a bare KeyError when the id is not present in the annotation file.
if args.category_id not in coco.cats:
    raise SystemExit(f"Unknown category id: {args.category_id}")
category = coco.cats[args.category_id]
logger.info(f"Using {len(coco.cats)} categories")
logger.debug(coco.cats)

logger.info(f"Limit to: {category}")

# One sub-directory per category name inside the output directory.
# makedirs(exist_ok=True) also creates a missing output directory and avoids
# the race between an existence check and the directory creation.
dirname = os.path.join(args.output, category['name'])
os.makedirs(dirname, exist_ok=True)

# The category-to-image index can list the same image id more than once
# (presumably once per matching annotation -- see the pycocotools index
# construction). De-duplicate, keeping the original order, so each SVG is
# generated only once.
image_ids = list(dict.fromkeys(coco.catToImgs[args.category_id]))

for img_id in tqdm.tqdm(image_ids):
    fn = os.path.join(dirname, f"{img_id}.svg")
    createSVGForImage(img_id, fn)

logger.info("Done")
|