
12 changed files with 1606 additions and 34 deletions
@ -0,0 +1,320 @@
@@ -0,0 +1,320 @@
|
||||
import logging |
||||
import os |
||||
import sqlite3 |
||||
import pycocotools.coco |
||||
import ast |
||||
import copy |
||||
import svgwrite |
||||
|
||||
logger = logging.getLogger('coco.storage') |
||||
|
||||
class Annotation: |
||||
def __init__(self, result, storage): |
||||
self.storage = storage |
||||
self.id = result['id'] |
||||
self.image_id = result['image_id'] |
||||
self.category_id = result['category_id'] |
||||
self.iscrowd = bool(result['iscrowd']) |
||||
self.area = result['area'] |
||||
self.bbox = [result['bbox_left'], result['bbox_top'], result['bbox_width'], result['bbox_height']] |
||||
self.segments = self.fetchSegments() |
||||
self.is_normalised = False |
||||
if type(result['zerkine_moment']) is list: |
||||
self.zerkine_moment = result['zerkine_moment'] # when normalising, this is already there |
||||
else: |
||||
self.zerkine_moment = self.parseZerkineFromDB(result['zerkine_moment']) if result['zerkine_moment'] else None |
||||
|
||||
@classmethod |
||||
def parseZerkineFromDB(cls, r): |
||||
z = r.split(' ') |
||||
return [float(i) for i in z] |
||||
|
||||
def fetchSegments(self): |
||||
try: |
||||
cur = self.storage.con.cursor() |
||||
cur.execute("SELECT * FROM segments WHERE annotation_id = :id AND points != 'ount' AND points != 'iz'", {'id': self.id}) |
||||
segments = [] |
||||
for row in cur: |
||||
segments.append(Segment(row)) |
||||
except Exception as e: |
||||
logger.critical(f"Invalid segment for annotation {self.id}") |
||||
logger.exception(e) |
||||
raise(e) |
||||
return segments |
||||
|
||||
def getNormalised(self, width, height) -> 'Annotation': |
||||
''' |
||||
center segments in boundig box with given width and height, and on point 0,0 |
||||
''' |
||||
scale = min(width/self.bbox[2], height/self.bbox[3]) |
||||
logger.debug(f"Normalise from bbox: {self.bbox}") |
||||
new_width = self.bbox[2] * scale |
||||
new_height = self.bbox[3] * scale |
||||
|
||||
dx = (width - new_width) / 2 |
||||
dy = (height - new_height) / 2 |
||||
|
||||
data = self.forJson() |
||||
data['bbox_left'] = 0 |
||||
data['bbox_top'] = 0 |
||||
data['bbox_width'] = new_width |
||||
data['bbox_height'] = new_height |
||||
newAnn = Annotation(data, self.storage) |
||||
newAnn.is_normalised = True |
||||
newAnn.bbox_original = self.bbox |
||||
newAnn.scale = scale |
||||
|
||||
for i, segment in enumerate(newAnn.segments): |
||||
newAnn.segments[i].points = [[ |
||||
(p[0]-self.bbox[0]) * scale, |
||||
(p[1]-self.bbox[1]) * scale |
||||
] for p in segment.points] |
||||
|
||||
|
||||
return newAnn |
||||
|
||||
def forJson(self): |
||||
data = self.__dict__.copy() |
||||
del data['storage'] |
||||
data['image'] = self.storage.getImage(data['image_id']) |
||||
return data |
||||
|
||||
def writeToDrawing(self, dwg, **pathSpecs): |
||||
for segment in self.segments: |
||||
if len(pathSpecs) == 0: |
||||
pathSpecs['fill'] = 'white' |
||||
dwg.add(svgwrite.path.Path(segment.getD(), class_=f"cat_{self.category_id}", **pathSpecs)) |
||||
|
||||
def getTranslationToCenter(self): |
||||
dimensions = (self.bbox[2], self.bbox[3]) |
||||
targetSize = max(dimensions) |
||||
dx = (dimensions[0] - targetSize)/2 |
||||
dy = (dimensions[1] - targetSize)/2 |
||||
return (dx, dy) |
||||
|
||||
def asSvg(self, filename, square=False, bg=None): |
||||
dimensions = (self.bbox[2], self.bbox[3]) |
||||
viewbox = copy.copy(self.bbox) |
||||
if square: |
||||
targetSize = max(dimensions) |
||||
dx = (dimensions[0] - targetSize)/2 |
||||
dy = (dimensions[1] - targetSize)/2 |
||||
viewbox[2] = targetSize |
||||
viewbox[3] = targetSize |
||||
dimensions = (targetSize, targetSize) |
||||
viewbox[0] += dx |
||||
viewbox[1] += dy |
||||
dwg = svgwrite.Drawing( |
||||
filename, |
||||
size=dimensions, |
||||
viewBox=" ".join([str(s) for s in viewbox]) |
||||
) |
||||
|
||||
if bg: |
||||
dwg.add(dwg.rect( |
||||
(viewbox[0],viewbox[1]), |
||||
(viewbox[2],viewbox[3]), |
||||
fill=bg)) |
||||
self.writeToDrawing(dwg) |
||||
return dwg |
||||
|
||||
class Segment(): |
||||
def __init__(self, result): |
||||
try: |
||||
self.points = self.asCoordinates(ast.literal_eval('['+result['points']+']')) |
||||
except Exception as e: |
||||
logger.critical(f"Exception loading segment for {result} {result['points']}") |
||||
raise |
||||
|
||||
@classmethod |
||||
def asCoordinates(cls, pointList): |
||||
points = [] |
||||
|
||||
r = len(pointList) / 2 |
||||
for i in range(int(r)): |
||||
points.append([ |
||||
pointList[(i)*2], |
||||
pointList[(i)*2+1] |
||||
]) |
||||
return points |
||||
|
||||
def getD(self): |
||||
start = self.points[0] |
||||
d = f'M{start[0]:.4f} {start[1]:.4f} L' |
||||
for i in range(1, len(self.points)): |
||||
p = self.points[i] |
||||
d += f' {p[0]:.4f} {p[1]:.4f}' |
||||
d += " Z" # segments are always closed |
||||
return d |
||||
|
||||
def forJson(self): |
||||
return self.points |
||||
|
||||
class COCOStorage: |
||||
def __init__(self, filename): |
||||
self.logger = logging.getLogger('coco.storage') |
||||
self.filename = filename |
||||
if not os.path.exists(self.filename): |
||||
con = sqlite3.connect(self.filename) |
||||
cur = con.cursor() |
||||
with open('../coco.sql', 'r') as fp: |
||||
cur.executescript(fp.read()) |
||||
con.close() |
||||
|
||||
self.con = sqlite3.connect(self.filename) |
||||
self.con.row_factory = sqlite3.Row |
||||
|
||||
def propagateFromAnnotations(self, filename): |
||||
self.logger.info(f"Load {filename}") |
||||
coco = pycocotools.coco.COCO(filename) |
||||
|
||||
self.logger.info("Create categories") |
||||
cur = self.con.cursor() |
||||
cur.executemany('INSERT OR IGNORE INTO categories(id, supercategory, name) VALUES (:id, :supercategory, :name)', coco.cats.values()) |
||||
self.con.commit() |
||||
|
||||
self.logger.info("Images...") |
||||
cur.executemany(''' |
||||
INSERT OR IGNORE INTO images(id, flickr_url, coco_url, width, height, date_captured) |
||||
VALUES (:id, :flickr_url, :coco_url, :width, :height, :date_captured) |
||||
''', coco.imgs.values()) |
||||
self.con.commit() |
||||
|
||||
self.logger.info("Annotations...") |
||||
|
||||
|
||||
def annotation_generator(): |
||||
for c in coco.anns.values(): |
||||
ann = c.copy() |
||||
ann['bbox_top'] = ann['bbox'][1] |
||||
ann['bbox_left'] = ann['bbox'][0] |
||||
ann['bbox_width'] = ann['bbox'][2] |
||||
ann['bbox_height'] = ann['bbox'][3] |
||||
yield ann |
||||
|
||||
cur.executemany(''' |
||||
INSERT OR IGNORE INTO annotations(id, image_id, category_id, iscrowd, area, bbox_top, bbox_left, bbox_width, bbox_height) |
||||
VALUES (:id, :image_id, :category_id, :iscrowd, :area, :bbox_top, :bbox_left, :bbox_width, :bbox_height) |
||||
''', annotation_generator()) |
||||
self.con.commit() |
||||
|
||||
|
||||
self.logger.info("Segments...") |
||||
|
||||
def segment_generator(): |
||||
for ann in coco.anns.values(): |
||||
for i, seg in enumerate(ann['segmentation']): |
||||
yield { |
||||
'id': ann['id']*10 + i, # create a uniqe segment id, supports max 10 segments per annotation |
||||
'annotation_id': ann['id'], |
||||
'points': str(seg)[1:-1], |
||||
} |
||||
|
||||
cur.executemany(''' |
||||
INSERT OR IGNORE INTO segments(id, annotation_id, points) |
||||
VALUES (:id, :annotation_id, :points) |
||||
''', segment_generator()) |
||||
self.con.commit() |
||||
|
||||
|
||||
self.logger.info("Done...") |
||||
|
||||
def getCategories(self): |
||||
if not hasattr(self, 'categories'): |
||||
cur = self.con.cursor() |
||||
cur.execute("SELECT * FROM categories ORDER BY id") |
||||
self.categories = [dict(cat) for cat in cur] |
||||
return self.categories |
||||
|
||||
def getCategory(self, cid): |
||||
cats = self.getCategories() |
||||
cat = [c for c in cats if c['id'] == cid] |
||||
if not len(cat): |
||||
return None |
||||
return cat[0] |
||||
|
||||
def getImage(self, image_id: int): |
||||
cur = self.con.cursor() |
||||
cur.execute(f"SELECT * FROM images WHERE id = ? LIMIT 1", (image_id,)) |
||||
img = cur.fetchone() |
||||
return dict(img) |
||||
|
||||
def getAnnotationWithoutZerkine(self): |
||||
cur = self.con.cursor() |
||||
# annotation 918 and 2206849 have 0 height. Crashing the script... exclude them |
||||
cur.execute(f"SELECT * FROM annotations WHERE zerkine_moment IS NULL AND area > 0 LIMIT 1") |
||||
ann = cur.fetchone() |
||||
if ann: |
||||
return Annotation(ann, self) |
||||
else: |
||||
return None |
||||
|
||||
def countAnnotationsWithoutZerkine(self): |
||||
cur = self.con.cursor() |
||||
|
||||
cur.execute(f"SELECT count(id) FROM annotations WHERE zerkine_moment IS NULL AND area > 0") |
||||
return int(cur.fetchone()[0]) |
||||
|
||||
def storeZerkineForAnnotation(self, annotation, moments, delayCommit = False): |
||||
m = ' '.join([str(m) for m in moments]) |
||||
cur = self.con.cursor() |
||||
|
||||
cur.execute( |
||||
"UPDATE annotations SET zerkine_moment = :z WHERE id = :id", |
||||
{'z': m, 'id': annotation.id} |
||||
) |
||||
if not delayCommit: |
||||
self.con.commit() |
||||
return True |
||||
|
||||
def getZerkines(self): |
||||
cur = self.con.cursor() |
||||
cur.execute(f"SELECT id, zerkine_moment FROM annotations WHERE zerkine_moment IS NOT NULL") |
||||
return cur.fetchall() |
||||
|
||||
def getAllAnnotationPoints(self): |
||||
cur = self.con.cursor() |
||||
cur.execute(f"SELECT annotations.id, points FROM annotations INNER JOIN segments ON segments.annotation_id = annotations.id WHERE area > 0") |
||||
return cur.fetchall() |
||||
|
||||
def getAnnotationById(self, annotation_id = None, withZerkine = False): |
||||
if annotation_id == -1: |
||||
annotation_id = None |
||||
return self.getRandomAnnotation(annotation_id = annotation_id, withZerkine = withZerkine) |
||||
|
||||
def getRandomAnnotation(self, annotation_id = None, category_id = None, withZerkine = False): |
||||
result = self.getRandomAnnotations(annotation_id, category_id, withZerkine, limit=1) |
||||
return result[0] if len(result) else None |
||||
|
||||
def getRandomAnnotations(self, annotation_id = None, category_id = None, withZerkine = False, limit=None): |
||||
cur = self.con.cursor() |
||||
where = "" |
||||
params = [] |
||||
if annotation_id: |
||||
where = "id = ?" |
||||
params.append(annotation_id) |
||||
elif category_id: |
||||
where = "category_id = ?" |
||||
params.append(category_id) |
||||
else: |
||||
where = "1=1" |
||||
|
||||
if withZerkine: |
||||
where += " AND zerkine_moment IS NOT NULL" |
||||
|
||||
sqlLimit = "" |
||||
if limit: |
||||
sqlLimit = f"LIMIT {int(limit)}" |
||||
|
||||
cur.execute(f"SELECT * FROM annotations WHERE {where} ORDER BY RANDOM() {sqlLimit}", tuple(params)) |
||||
results = [] |
||||
for ann in cur: |
||||
results.append(Annotation(ann, self)) |
||||
return results |
||||
# ann = cur.fetchall() |
||||
# |
||||
# if ann: |
||||
# return Annotation(ann, self) |
||||
# else: |
||||
# return None |
||||
|
@ -0,0 +1,81 @@
@@ -0,0 +1,81 @@
|
||||
import pycocotools.coco |
||||
import argparse |
||||
import logging |
||||
import tqdm |
||||
import urllib.request |
||||
import os |
||||
import svgwrite |
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO) |
||||
logger = logging.getLogger("coco") |
||||
|
||||
argParser = argparse.ArgumentParser(description='Create shape SVG\'s') |
||||
argParser.add_argument( |
||||
'--annotations', |
||||
type=str, |
||||
default='../../datasets/COCO/annotations/instances_train2017.json' |
||||
) |
||||
argParser.add_argument( |
||||
'--output', |
||||
type=str, |
||||
help='Output directory' |
||||
) |
||||
argParser.add_argument( |
||||
'--category_id', |
||||
type=int, |
||||
help='Category id' |
||||
) |
||||
args = argParser.parse_args() |
||||
|
||||
|
||||
logger.info(f"Load {args.annotations}") |
||||
coco = pycocotools.coco.COCO(args.annotations) |
||||
|
||||
def createSVGForImage(img_id, filename): |
||||
dimensions = (coco.imgs[img_id]['width'], coco.imgs[img_id]['height']) |
||||
dwg = svgwrite.Drawing(filename, size=dimensions) |
||||
for i, annotation in enumerate(coco.imgToAnns[img_id]): |
||||
if 'counts' in annotation['segmentation']: |
||||
# skip RLE masks |
||||
continue; |
||||
|
||||
addAnnotationToSVG(dwg, annotation) |
||||
dwg.save() |
||||
|
||||
def addAnnotationToSVG(dwg, annotation): |
||||
segmentation = annotation['segmentation'] |
||||
|
||||
|
||||
p = "" |
||||
for shape in segmentation: |
||||
# segmentation can have multiple shapes |
||||
p = f"M{shape[0]},{shape[1]} L " |
||||
r = len(shape) / 2 - 1 |
||||
for i in range(int(r)): |
||||
p += f"{shape[(i+1)*2]},{shape[(i+1)*2+1]} " |
||||
|
||||
cat = coco.cats[annotation['category_id']] |
||||
shape_classes = ' '.join([ |
||||
f"super_{cat['supercategory']}", |
||||
f"cat_{cat['name']}" |
||||
]) |
||||
dwg.add(dwg.path(p, class_=shape_classes)) |
||||
|
||||
category = coco.cats[args.category_id] |
||||
logger.info(f"Using {len(coco.cats)} categories") |
||||
logger.debug(coco.cats) |
||||
|
||||
logger.info(f"Limit to: {category}") |
||||
|
||||
|
||||
dirname = os.path.join(args.output, category['name']) |
||||
|
||||
if not os.path.exists(dirname): |
||||
os.mkdir(dirname) |
||||
|
||||
for img_id in tqdm.tqdm(coco.catToImgs[args.category_id]): |
||||
fn = os.path.join(dirname, f"{img_id}.svg") |
||||
createSVGForImage(img_id, fn) |
||||
|
||||
logger.info("Done") |
After Width: | Height: | Size: 70 KiB |
After Width: | Height: | Size: 622 KiB |
After Width: | Height: | Size: 47 KiB |
After Width: | Height: | Size: 411 KiB |
After Width: | Height: | Size: 25 KiB |
After Width: | Height: | Size: 1.1 MiB |
Loading…
Reference in new issue