# coco/tools.py
#
# 299 lines
# 11 KiB
# Python

import argparse
import ast
import logging
import os
import pprint
import sqlite3
import subprocess
import tempfile
from xml.etree import ElementTree

import cv2
import mahotas
import numpy as np
import pycocotools.coco
import svgwrite
import tqdm
from svgwrite.extensions import Inkscape

from coco.storage import COCOStorage, Annotation, Segment
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coco")

# Command-line interface. The action flags are independent of each other, so
# several can be combined in one run. --propagate, --zerkine, --similar and
# --stickers all operate on the SQLite database given with --db.
argParser = argparse.ArgumentParser(description='Create shape SVG\'s')
argParser.add_argument(
    '--annotations',
    type=str,
    default='../../datasets/COCO/annotations/instances_val2017.json'
)
argParser.add_argument(
    '--categories',
    action='store_true',
    help='Show categories'
)
argParser.add_argument(
    '--db',
    type=str,
    metavar='DATABASE',
    help='SQLite db filename, will be created if not existing'
)
argParser.add_argument(
    '--propagate',
    action='store_true',
    help='Store annotation data in sqlite db'
)
# The flag keeps its historical spelling ("zerkine") so existing invocations
# keep working; only the help text uses the correct name, Zernike.
argParser.add_argument(
    '--zerkine',
    action='store_true',
    help='Find and store annotation Zernike moments for those that do not have them yet'
)
argParser.add_argument(
    '--similar',
    type=int,
    metavar="ANNOTATION_ID",
    help='Find similar shapes for annotation'
)
# The value is used as a directory: one SVG per category is written into it.
argParser.add_argument(
    '--stickers',
    type=str,
    metavar="SVG_DIRECTORY",
    help=(
        'Directory in which to create one sticker-page SVG per category '
        '(afterwards convert to EPS: "for f in *; do echo $f; '
        'inkscape -f $f --export-eps $f.eps; done")'
    )
)
# Parse the command line and load the COCO annotation file up front; the
# actions further down all read from this `coco` index.
args = argParser.parse_args()
logger.info(f"Load {args.annotations}")
coco = pycocotools.coco.COCO(args.annotations)

if args.categories:
    # Group the category records under their supercategory, in the order in
    # which they appear in coco.cats, and print the resulting mapping.
    cats = {}
    for category in coco.cats.values():
        cats.setdefault(category['supercategory'], []).append(category)
    pprint.pprint(cats, sort_dicts=False)
# Open the SQLite storage when --db is given. `storage` stays None otherwise
# so that later actions can detect the missing database.
storage = None
if args.db:
    storage = COCOStorage(args.db)
    con = storage.con

    if args.propagate:
        # Copy categories, images, annotations and polygon segments from the
        # loaded COCO file into the database. INSERT OR IGNORE makes the
        # import re-runnable: rows whose id already exists are left untouched.
        logger.info("Create categories")
        cur = con.cursor()
        cur.executemany(
            'INSERT OR IGNORE INTO categories(id, supercategory, name) VALUES (:id, :supercategory, :name)',
            coco.cats.values()
        )
        con.commit()

        logger.info("Images...")
        cur.executemany('''
            INSERT OR IGNORE INTO images(id, flickr_url, coco_url, width, height, date_captured)
            VALUES (:id, :flickr_url, :coco_url, :width, :height, :date_captured)
            ''', coco.imgs.values())
        con.commit()

        logger.info("Annotations...")

        def annotation_generator():
            """Yield each annotation with its bbox list unpacked into named keys."""
            for c in coco.anns.values():
                # Work on a copy so the loaded COCO index is not modified.
                ann = c.copy()
                # bbox is indexed as [left, top, width, height].
                ann['bbox_top'] = ann['bbox'][1]
                ann['bbox_left'] = ann['bbox'][0]
                ann['bbox_width'] = ann['bbox'][2]
                ann['bbox_height'] = ann['bbox'][3]
                yield ann

        cur.executemany('''
            INSERT OR IGNORE INTO annotations(id, image_id, category_id, iscrowd, area, bbox_top, bbox_left, bbox_width, bbox_height)
            VALUES (:id, :image_id, :category_id, :iscrowd, :area, :bbox_top, :bbox_left, :bbox_width, :bbox_height)
            ''', annotation_generator())
        con.commit()

        logger.info("Segments...")

        def segment_generator():
            """Yield one row per polygon of every polygon-annotated object."""
            for ann in coco.anns.values():
                segmentation = ann['segmentation']
                if not isinstance(segmentation, list):
                    # RLE-encoded masks are a dict rather than a list of
                    # polygons; enumerating the dict would store its keys
                    # as bogus point strings, so skip these annotations.
                    continue
                if len(segmentation) > 10:
                    logger.warning(
                        f"Annotation {ann['id']} has {len(segmentation)} segments; "
                        "segment ids beyond the 10th may collide with those of other annotations"
                    )
                for i, seg in enumerate(segmentation):
                    yield {
                        # create a unique segment id, supports max 10 segments per annotation
                        'id': ann['id']*10 + i,
                        'annotation_id': ann['id'],
                        # str() of the coordinate list without the surrounding brackets
                        'points': str(seg)[1:-1],
                    }

        cur.executemany('''
            INSERT OR IGNORE INTO segments(id, annotation_id, points)
            VALUES (:id, :annotation_id, :points)
            ''', segment_generator())
        con.commit()

        logger.info("Done...")
elif args.propagate:
    # Without a database there is nowhere to store the data: say so.
    argParser.error("--propagate requires --db")
if args.zerkine:
    # Compute Zernike moments for every annotation that has none stored yet:
    # render the normalised shape to SVG, rasterise it with Inkscape and feed
    # the resulting grayscale image to mahotas.
    if storage is None:
        argParser.error("--zerkine requires --db")

    nr = storage.countAnnotationsWithoutZerkine()
    # A private temporary directory avoids clashing with other runs (or other
    # users) over a fixed path in /tmp, and is removed when the work is done.
    with tempfile.TemporaryDirectory(prefix='coco_zerkine_') as tmpDir:
        svgFilename = os.path.join(tmpDir, 'annotation.svg')
        pngFilename = os.path.join(tmpDir, 'annotation.png')
        try:
            for i in tqdm.tqdm(range(nr)):
                annotation = storage.getAnnotationWithoutZerkine()
                normAnn = annotation.getNormalised(100, 100)
                dwg = normAnn.asSvg(svgFilename, square=True, bg='black')
                dwg.save()

                # convert to rasterised
                # NOTE(review): -f/-e look like the Inkscape 0.9x command line;
                # confirm against the installed Inkscape version.
                # check=True: a failed conversion must stop the run instead of
                # letting a missing or outdated PNG be processed.
                subprocess.run(
                    ['inkscape', '-f', svgFilename, '-e', pngFilename],
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                    check=True,
                )

                # read with opencv; imread returns None when it cannot load the file
                image = cv2.imread(pngFilename, cv2.IMREAD_GRAYSCALE)
                if image is None:
                    raise RuntimeError(
                        f"Could not read rasterised annotation from {pngFilename}"
                    )
                # Remove the PNG so the next iteration can never pick up this one.
                os.remove(pngFilename)

                moments = mahotas.features.zernike_moments(image, 21)
                storage.storeZerkineForAnnotation(annotation, moments, delayCommit=True)

                # Commit in batches of 100 rather than once per annotation.
                if i % 100 == 0:
                    storage.con.commit()
        finally:
            # Keep whatever was computed, also when the loop is interrupted.
            storage.con.commit()
if args.similar is not None:
    # Rank all stored shapes by their cv2.matchShapes distance to the first
    # segment of the requested annotation and write the closest ones as SVG
    # files into ./tmp.
    if storage is None:
        argParser.error("--similar requires --db")

    annotation = storage.getAnnotationById(args.similar, withZerkine=True)
    # The output directory is relative to the working directory; make sure it exists.
    os.makedirs('tmp', exist_ok=True)
    dwg = annotation.asSvg('tmp/source.svg', square=True)
    dwg.save()

    shapeA = np.array(annotation.segments[0].points)
    distances = []
    print(annotation)

    # NOTE(review): the requested annotation is not excluded from the
    # candidates, so it may show up among its own results — confirm whether
    # that is wanted.
    # NOTE(review): cv2.matchShapes expects contours as int32/float32 point
    # arrays — confirm the dtype produced by Segment.asCoordinates.
    anns = storage.getAllAnnotationPoints()
    for ann in tqdm.tqdm(anns):
        try:
            shapeB = np.array(Segment.asCoordinates(ast.literal_eval('['+ann['points']+']')))
            # fourth param is required, but according to docs does nothing
            distance = cv2.matchShapes(shapeA, shapeB, cv2.CONTOURS_MATCH_I2, 1)
            distances.append((ann['id'], distance))
        except Exception as e:
            # Best effort: one unparsable or incomparable shape must not
            # abort the whole search, so log it and carry on.
            logger.critical(f"Exception comparing {annotation.id} to {ann['id']}, points: {ann['points']}")
            logger.exception(e)

    # Smallest distance first; slicing copes with fewer than 10 candidates.
    distances.sort(key=lambda d: d[1])
    for i, (similarId, _distance) in enumerate(distances[:10]):
        similarAnnotation = storage.getAnnotationById(similarId)
        print(coco.cats[similarAnnotation.category_id])
        dwg = similarAnnotation.asSvg(f'tmp/result_{i}.svg', square=True)
        dwg.save()
if args.stickers:
    # Write one sticker page (SVG) per COCO category into the directory given
    # with --stickers. Each page holds a grid of random annotation shapes of
    # that category, drawn twice: as outlines on a cut-line layer and filled
    # with the category's texture pattern on a shapes layer.
    if storage is None:
        argParser.error("--stickers requires --db")

    grid = (3, 4)  # items in the grid x,y
    size = (105, 148)  # in mm
    sizeFactor = 5  # influences the size of the patterns
    viewBoxSize = (size[0] * sizeFactor, size[1] * sizeFactor)
    margin = 5
    # Cell size: what is left of the view box after subtracting the margins
    # around and between the cells, divided over the cells.
    gridSize = (
        int((viewBoxSize[0]-((grid[0]+1)*margin))/grid[0]),
        int((viewBoxSize[1]-((grid[1]+1)*margin))/grid[1])
    )
    font_size = 10

    # see also textures.xml
    # NOTE(review): "#lbnx" is the only id with four letters — confirm it
    # matches a pattern id in textures.xml.
    textureIds = [
        "#siqwx", "#wjnbs", "#pnfez", "#ejtxy", "#obabs", "#hehoj", "#mrwjs", "#ryjbw", "#rkkau", "#vbjcl",
        "#zzehx", "#mumke", "#brhhk", "#gujvh", "#hfgqa", "#lrbsh", "#bndby", "#bfnxk", "#ydler", "#pnxdr",
        "#htqlj", "#nunnt", "#tidaw", "#tcdum", "#kwwja", "#hgdkl", "#nvkwz", "#uzdqb", "#fgshk", "#vknil",
        "#yeenr", "#mslkw", "#eibaw", "#meama", "#akuvz", "#khkpp", "#ibnow", "#wivvx", "#svksy", "#xhmew",
        "#jmiqu", "#gfcer", "#iueil", "#iufvt", "#ugkud", "#dchzd", "#nejks", "#dqseb", "#yhrwm", "#bmiet",
        "#qovkk", "#hxoiq", "#jfguh", "#kbpkl", "#ikarj", "#nucap", "#qfsqn", "#bboqt", "#pxkjn", "#lbnx",
        "#nxkmp", "#snojb", "#oioil", "#hvldz", "#qpscp", "#oborh", "#crobu", "#ydhwn", "#geanf", "#sdfeo",
        "#cgtma", "#rjfrc", "#uhcys", "#lrgem", "#osiho", "#etssd", "#esxcs", "#hczhr", "#nnhxw", "#wrlbu",
    ]

    # The texture definitions are the same for every page: parse them once
    # instead of once per category.
    textureTree = ElementTree.parse('textures.xml').getroot()

    # Make sure the output directory exists before any page is written.
    os.makedirs(args.stickers, exist_ok=True)

    total_nr = len(coco.cats)
    for nr, (category_id, cat) in enumerate(coco.cats.items(), start=1):
        filename = os.path.join(
            args.stickers,
            f"{category_id}_{cat['supercategory']}_{cat['name']}.svg")
        dwg = svgwrite.Drawing(
            filename,
            size=(f'{size[0]}mm', f'{size[1]}mm'),
            viewBox=f"0 0 {viewBoxSize[0]} {viewBoxSize[1]}"
        )
        annotations = storage.getRandomAnnotations(
            limit=grid[0]*grid[1],
            category_id=category_id
        )

        # Two Inkscape layers: 'Snijlijnen' (Dutch for "cut lines") with the
        # outlines, and 'Shapes' with the textured fills and the labels.
        inkscape = Inkscape(dwg)
        contourG = inkscape.layer(label='Snijlijnen')
        drawingG = inkscape.layer(label='Shapes')
        dwg.add(drawingG)
        dwg.add(contourG)

        # Corner labels: page counter (top left), category (top right),
        # dataset name (bottom left) and project name (bottom right).
        text = dwg.text(
            f"{nr:02d}/{total_nr}",
            insert=(margin, margin+font_size), font_size=font_size, fill='black'
        )
        drawingG.add(text)
        text = dwg.text(
            f"{category_id}. {cat['supercategory']} - {cat['name']}",
            insert=(viewBoxSize[0]-margin, margin+font_size), font_size=font_size, fill='black',
            style='text-anchor:end;')
        drawingG.add(text)
        text = dwg.text(
            "Common Objects In Context",
            insert=(margin, viewBoxSize[1]-margin), font_size=font_size, fill='black',
        )
        drawingG.add(text)
        text = dwg.text(
            "Plotting Data",
            insert=(viewBoxSize[0]-margin, viewBoxSize[1]-margin), font_size=font_size, fill='black',
            style='text-anchor:end;')
        drawingG.add(text)

        # One texture per category; wrap around if there are more categories
        # than textures.
        pattern_id = textureIds[category_id % len(textureIds)]

        for i, annotation in enumerate(annotations):
            normAnn = annotation.getNormalised(gridSize[0], gridSize[1])
            translation = normAnn.getTranslationToCenter()

            # Fill the grid row by row: column pX, row pY.
            pY, pX = divmod(i, grid[0])
            posX = pX*gridSize[0] + (pX+1)*margin - translation[0]
            posY = pY*gridSize[1] + (pY+1)*margin - translation[1]
            transform = f'translate({posX}, {posY})'

            # Outline only, on the cut-line layer.
            positionG = svgwrite.container.Group(transform=transform)
            normAnn.writeToDrawing(positionG, stroke='#2FEE2F', stroke_width='1pt', fill_opacity="0")
            contourG.add(positionG)

            # Same shape at the same position, filled with the texture.
            position2G = svgwrite.container.Group(transform=transform)
            normAnn.writeToDrawing(position2G, fill=f'url({pattern_id})', stroke='blue', stroke_width='0')
            drawingG.add(position2G)

        # Add the texture patterns to the drawing's <defs> in the XML tree
        # and write that tree out ourselves.
        xml = dwg.get_xml()
        defsTree = xml.find('defs')
        for pattern in textureTree:
            defsTree.append(pattern)

        xmlString = ElementTree.tostring(xml)
        with open(filename, 'wb') as fp:
            fp.write(xmlString)
        logger.info(f"Wrote to {filename}")