You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.1 KiB
81 lines
2.1 KiB
import pycocotools.coco |
|
import argparse |
|
import logging |
|
import tqdm |
|
import urllib.request |
|
import os |
|
import svgwrite |
|
|
|
|
|
# Module-level setup: logging, command-line parsing, and loading the COCO
# annotation index that the functions below read through the `coco` global.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco")

argParser = argparse.ArgumentParser(description="Create shape SVG's")
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_train2017.json',
    help='Path to the COCO instances annotation JSON file',
)
# --output and --category_id have no usable default: both are used
# unconditionally later in the script (output path, category lookup). Making
# them required fails fast with a usage message instead of crashing only
# after the (large) annotation file has been loaded.
argParser.add_argument(
    '--output',
    type=str,
    required=True,
    help='Output directory',
)
argParser.add_argument(
    '--category_id',
    type=int,
    required=True,
    help='Category id',
)
args = argParser.parse_args()

logger.info("Load %s", args.annotations)
coco = pycocotools.coco.COCO(args.annotations)
|
|
|
def createSVGForImage(img_id, filename):
    """Write one SVG file containing the polygon annotations of a COCO image.

    The drawing is sized to the image's ``width`` x ``height`` from
    ``coco.imgs``. Every annotation in ``coco.imgToAnns[img_id]`` is passed to
    ``addAnnotationToSVG``; no category filtering is applied here. Annotations
    whose segmentation contains a ``counts`` entry (RLE-encoded masks) are
    skipped because they cannot be drawn as polygon paths.

    Args:
        img_id: COCO image id, used as a key into ``coco.imgs`` and
            ``coco.imgToAnns``.
        filename: Path of the SVG file to write.
    """
    image_info = coco.imgs[img_id]
    dwg = svgwrite.Drawing(filename, size=(image_info['width'], image_info['height']))

    for annotation in coco.imgToAnns[img_id]:
        if 'counts' in annotation['segmentation']:
            # RLE-encoded mask, not a polygon list: skip it.
            continue
        addAnnotationToSVG(dwg, annotation)

    dwg.save()
|
|
|
def addAnnotationToSVG(dwg, annotation):
    """Add one polygon annotation to ``dwg`` as a single SVG ``<path>``.

    ``annotation['segmentation']`` is expected to be a list of polygons, each
    a flat ``[x0, y0, x1, y1, ...]`` coordinate list (the caller skips RLE
    masks). The segmentation can contain several polygons; each one becomes
    its own closed subpath of the same path element, so every piece is drawn.
    Polygons with fewer than one full (x, y) point are ignored. If nothing
    drawable remains, no element is added.

    The path gets two CSS classes, ``super_<supercategory>`` and
    ``cat_<name>``, looked up in ``coco.cats``. Whitespace inside those names
    is replaced by underscores so each stays a single class token (a name such
    as "traffic light" becomes ``cat_traffic_light``).

    Args:
        dwg: svgwrite drawing that receives the path.
        annotation: COCO annotation dict with ``segmentation`` and
            ``category_id`` keys.
    """
    subpaths = []
    for shape in annotation['segmentation']:
        # Pair the flat coordinate list into points; a trailing odd value is
        # ignored, matching the original index arithmetic.
        points = list(zip(shape[0::2], shape[1::2]))
        if not points:
            continue
        (x0, y0), rest = points[0], points[1:]
        subpath = f"M{x0},{y0}"
        if rest:
            # Only emit the lineto command when it has coordinates; a bare
            # "L" is invalid path data.
            subpath += " L " + " ".join(f"{x},{y}" for x, y in rest)
        # Treat each polygon as a closed outline; without "Z" a stroked path
        # would leave its last edge open.
        subpaths.append(subpath + " Z")

    if not subpaths:
        # Nothing drawable: avoid emitting an empty <path d="">.
        return

    cat = coco.cats[annotation['category_id']]
    shape_classes = ' '.join([
        "super_" + "_".join(str(cat['supercategory']).split()),
        "cat_" + "_".join(str(cat['name']).split()),
    ])
    dwg.add(dwg.path(" ".join(subpaths), class_=shape_classes))
|
|
|
# Resolve the requested category; report an unknown id as a usage error
# rather than a bare KeyError traceback.
if args.category_id not in coco.cats:
    argParser.error(f"unknown --category_id {args.category_id!r}")
category = coco.cats[args.category_id]
logger.info("Using %d categories", len(coco.cats))
logger.debug(coco.cats)
logger.info("Limit to: %s", category)

# SVGs are written to <output>/<category name>/<img_id>.svg.
dirname = os.path.join(args.output, category['name'])
# makedirs also creates a missing --output directory; exist_ok keeps reruns safe.
os.makedirs(dirname, exist_ok=True)

# coco.catToImgs can list the same image id more than once (pycocotools
# appends one id per annotation). Dedupe while keeping order so each image's
# SVG is written exactly once and the progress bar counts images.
img_ids = list(dict.fromkeys(coco.catToImgs[args.category_id]))
for img_id in tqdm.tqdm(img_ids):
    fn = os.path.join(dirname, f"{img_id}.svg")
    createSVGForImage(img_id, fn)

logger.info("Done")