"""SQLite-backed storage for COCO annotations, plus SVG rendering helpers."""
import ast
import copy
import logging
import os
import sqlite3

import pycocotools.coco
import svgwrite

logger = logging.getLogger('coco.storage')


class Annotation:
    """One row of the ``annotations`` table together with its polygon segments.

    Instances are built either from a ``sqlite3.Row`` of the ``annotations``
    table, or from a dict produced by :meth:`forJson` (see :meth:`getNormalised`).
    """

    def __init__(self, result, storage):
        """
        :param result: mapping providing ``id``, ``image_id``, ``category_id``,
            ``iscrowd``, ``area``, ``bbox_left``, ``bbox_top``, ``bbox_width``,
            ``bbox_height`` and ``zerkine_moment``.
        :param storage: the owning :class:`COCOStorage`, used to load segments.
        """
        self.storage = storage
        self.id = result['id']
        self.image_id = result['image_id']
        self.category_id = result['category_id']
        self.iscrowd = bool(result['iscrowd'])
        self.area = result['area']
        # [left, top, width, height]
        self.bbox = [result['bbox_left'], result['bbox_top'],
                     result['bbox_width'], result['bbox_height']]
        self.segments = self.fetchSegments()
        self.is_normalised = False
        zerkine = result['zerkine_moment']
        if isinstance(zerkine, list):
            # when normalising, the moments are already parsed
            self.zerkine_moment = zerkine
        else:
            # stored in the DB as a space-separated string (or NULL)
            self.zerkine_moment = self.parseZerkineFromDB(zerkine) if zerkine else None

    @classmethod
    def parseZerkineFromDB(cls, r):
        """Parse a space-separated string of moments into a list of floats."""
        return [float(i) for i in r.split(' ')]

    def fetchSegments(self):
        """Load this annotation's polygon segments from the ``segments`` table.

        Rows whose points are ``'ount'`` or ``'iz'`` are excluded: they are the
        mangled ``'counts'``/``'size'`` keys of RLE (crowd) segmentations that
        older imports stored as if they were polygons.

        :raises Exception: re-raised (after logging) if any segment fails to parse.
        """
        try:
            cur = self.storage.con.cursor()
            cur.execute(
                "SELECT * FROM segments WHERE annotation_id = :id AND points != 'ount' AND points != 'iz'",
                {'id': self.id})
            segments = [Segment(row) for row in cur]
        except Exception as e:
            logger.critical(f"Invalid segment for annotation {self.id}")
            logger.exception(e)
            raise
        return segments

    def getNormalised(self, width, height) -> 'Annotation':
        """Return a copy scaled to fit inside a ``width`` x ``height`` box.

        The aspect ratio is preserved and the bounding box is moved to the
        origin, so the result's bbox is ``[0, 0, new_width, new_height]``. The
        shape is not centred inside the target box. ``area`` and
        ``zerkine_moment`` are copied unchanged. The source bbox and the scale
        factor are kept on the result as ``bbox_original`` and ``scale``.

        :raises ZeroDivisionError: if the bbox has zero width or height.
        """
        scale = min(width / self.bbox[2], height / self.bbox[3])
        logger.debug(f"Normalise from bbox: {self.bbox}")

        data = self.forJson()
        data['bbox_left'] = 0
        data['bbox_top'] = 0
        data['bbox_width'] = self.bbox[2] * scale
        data['bbox_height'] = self.bbox[3] * scale

        newAnn = Annotation(data, self.storage)
        newAnn.is_normalised = True
        newAnn.bbox_original = self.bbox
        newAnn.scale = scale

        # Transform *this* annotation's points rather than the ones the
        # constructor just re-read from the database: those are always in
        # image coordinates, which is wrong when self is already normalised.
        left, top = self.bbox[0], self.bbox[1]
        newAnn.segments = []
        for segment in self.segments:
            moved = copy.copy(segment)
            moved.points = [[(p[0] - left) * scale, (p[1] - top) * scale]
                            for p in segment.points]
            newAnn.segments.append(moved)
        return newAnn

    def forJson(self):
        """Return the instance attributes as a dict, minus ``storage``.

        The image row is added under ``'image'`` (None if it is missing).
        ``segments`` still holds :class:`Segment` objects.
        """
        data = self.__dict__.copy()
        del data['storage']
        data['image'] = self.storage.getImage(data['image_id'])
        return data

    def writeToDrawing(self, dwg, **pathSpecs):
        """Add one SVG path per segment to ``dwg``.

        Keyword arguments are forwarded to ``svgwrite.path.Path``. When none
        are given, paths are filled white. Every path gets the CSS class
        ``cat_<category_id>``.
        """
        if not pathSpecs:
            pathSpecs['fill'] = 'white'
        for segment in self.segments:
            dwg.add(svgwrite.path.Path(segment.getD(), class_=f"cat_{self.category_id}", **pathSpecs))

    def getTranslationToCenter(self):
        """Return ``(dx, dy)`` that centres the bbox in a square of side max(w, h).

        Both values are <= 0 (the square extends left/up from the bbox).
        """
        dimensions = (self.bbox[2], self.bbox[3])
        targetSize = max(dimensions)
        dx = (dimensions[0] - targetSize) / 2
        dy = (dimensions[1] - targetSize) / 2
        return (dx, dy)

    def asSvg(self, filename, square=False, bg=None):
        """Build an ``svgwrite.Drawing`` whose viewBox is this annotation's bbox.

        :param filename: file name stored in the drawing; nothing is saved here.
        :param square: if true, widen the viewBox to a square with the shape centred.
        :param bg: optional fill colour of a background rectangle covering the viewBox.
        :returns: the drawing (call ``save()`` on it to write the file).
        """
        dimensions = (self.bbox[2], self.bbox[3])
        viewbox = copy.copy(self.bbox)
        if square:
            dx, dy = self.getTranslationToCenter()
            targetSize = max(dimensions)
            viewbox[0] += dx
            viewbox[1] += dy
            viewbox[2] = targetSize
            viewbox[3] = targetSize
            dimensions = (targetSize, targetSize)

        dwg = svgwrite.Drawing(
            filename,
            size=dimensions,
            viewBox=" ".join(str(s) for s in viewbox),
        )
        if bg:
            dwg.add(dwg.rect((viewbox[0], viewbox[1]), (viewbox[2], viewbox[3]), fill=bg))
        self.writeToDrawing(dwg)
        return dwg


class Segment:
    """A closed polygon stored as a list of ``[x, y]`` points."""

    def __init__(self, result):
        """:param result: a ``segments`` row whose ``points`` is a comma-separated
        list of coordinates ``x1, y1, x2, y2, ...``."""
        try:
            # literal_eval only accepts Python literals, never arbitrary code
            self.points = self.asCoordinates(ast.literal_eval('[' + result['points'] + ']'))
        except Exception:
            logger.critical(f"Exception loading segment for {result} {result['points']}")
            raise

    @classmethod
    def asCoordinates(cls, pointList):
        """Pair a flat ``[x1, y1, x2, y2, ...]`` list into ``[[x1, y1], ...]``.

        A trailing unpaired value is dropped.
        """
        return [[x, y] for x, y in zip(pointList[0::2], pointList[1::2])]

    def getD(self):
        """Return SVG path data (the ``d`` attribute) for this closed polygon.

        Requires at least one point.
        """
        start = self.points[0]
        tail = ''.join(f' {p[0]:.4f} {p[1]:.4f}' for p in self.points[1:])
        return f'M{start[0]:.4f} {start[1]:.4f} L{tail} Z'  # segments are always closed

    def forJson(self):
        return self.points


class COCOStorage:
    """SQLite database of COCO categories, images, annotations and segments."""

    def __init__(self, filename):
        """Open ``filename``; if it does not exist, create it from ``coco.sql``
        (located next to this module)."""
        self.logger = logging.getLogger('coco.storage')
        self.filename = filename
        if not os.path.exists(self.filename):
            self._createDatabase()
        self.con = sqlite3.connect(self.filename)
        self.con.row_factory = sqlite3.Row

    def _createDatabase(self):
        """Create ``self.filename`` and apply the ``coco.sql`` schema to it.

        If anything fails, the half-created file is removed. Otherwise the next
        run would find the file and skip schema creation.
        """
        schemaPath = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'coco.sql')
        try:
            con = sqlite3.connect(self.filename)
            try:
                with open(schemaPath, 'r') as fp:
                    con.executescript(fp.read())
            finally:
                con.close()
        except Exception:
            if os.path.exists(self.filename):
                os.remove(self.filename)
            raise

    def propagateFromAnnotations(self, filename):
        """Import categories, images, annotations and polygon segments from a
        COCO annotation file. Rows whose id already exists are left untouched
        (``INSERT OR IGNORE``)."""
        self.logger.info(f"Load {filename}")
        coco = pycocotools.coco.COCO(filename)

        self.logger.info("Create categories")
        cur = self.con.cursor()
        cur.executemany('INSERT OR IGNORE INTO categories(id, supercategory, name) VALUES (:id, :supercategory, :name)', coco.cats.values())
        self.con.commit()

        self.logger.info("Images...")
        cur.executemany('''
            INSERT OR IGNORE INTO images(id, flickr_url, coco_url, width, height, date_captured)
            VALUES (:id, :flickr_url, :coco_url, :width, :height, :date_captured)
            ''', coco.imgs.values())
        self.con.commit()

        self.logger.info("Annotations...")

        def annotation_generator():
            for c in coco.anns.values():
                ann = c.copy()
                # flatten COCO's [left, top, width, height] bbox into columns
                ann['bbox_top'] = ann['bbox'][1]
                ann['bbox_left'] = ann['bbox'][0]
                ann['bbox_width'] = ann['bbox'][2]
                ann['bbox_height'] = ann['bbox'][3]
                yield ann

        cur.executemany('''
            INSERT OR IGNORE INTO annotations(id, image_id, category_id, iscrowd, area, bbox_top, bbox_left, bbox_width, bbox_height)
            VALUES (:id, :image_id, :category_id, :iscrowd, :area, :bbox_top, :bbox_left, :bbox_width, :bbox_height)
            ''', annotation_generator())
        self.con.commit()

        self.logger.info("Segments...")

        def segment_generator():
            for ann in coco.anns.values():
                segmentation = ann['segmentation']
                if isinstance(segmentation, dict):
                    # RLE (crowd) mask, not a list of polygons: iterating it
                    # would store its keys ('counts', 'size') as segments.
                    continue
                if len(segmentation) > 10:
                    self.logger.warning(
                        f"Annotation {ann['id']} has {len(segmentation)} segments; "
                        f"segment ids beyond 10 collide with other annotations and may be dropped")
                for i, seg in enumerate(segmentation):
                    yield {
                        'id': ann['id'] * 10 + i,  # create a uniqe segment id, supports max 10 segments per annotation
                        'annotation_id': ann['id'],
                        'points': str(seg)[1:-1],  # list repr without the brackets
                    }

        cur.executemany('''
            INSERT OR IGNORE INTO segments(id, annotation_id, points)
            VALUES (:id, :annotation_id, :points)
            ''', segment_generator())
        self.con.commit()
        self.logger.info("Done...")

    def getCategories(self):
        """Return all categories as dicts ordered by id (cached after the first call)."""
        if not hasattr(self, 'categories'):
            cur = self.con.cursor()
            cur.execute("SELECT * FROM categories ORDER BY id")
            self.categories = [dict(cat) for cat in cur]
        return self.categories

    def getCategory(self, cid):
        """Return the category dict with id ``cid``, or None."""
        return next((c for c in self.getCategories() if c['id'] == cid), None)

    def getImage(self, image_id: int):
        """Return the ``images`` row with id ``image_id`` as a dict, or None if absent."""
        cur = self.con.cursor()
        cur.execute("SELECT * FROM images WHERE id = ? LIMIT 1", (image_id,))
        img = cur.fetchone()
        return dict(img) if img is not None else None

    def getAnnotationWithoutZerkine(self):
        """Return one annotation with positive area and no stored moments, or None."""
        cur = self.con.cursor()
        # annotation 918 and 2206849 have 0 height. Crashing the script... exclude them
        cur.execute("SELECT * FROM annotations WHERE zerkine_moment IS NULL AND area > 0 LIMIT 1")
        ann = cur.fetchone()
        return Annotation(ann, self) if ann else None

    def countAnnotationsWithoutZerkine(self):
        """Count annotations with positive area and no stored moments."""
        cur = self.con.cursor()
        cur.execute("SELECT count(id) FROM annotations WHERE zerkine_moment IS NULL AND area > 0")
        return int(cur.fetchone()[0])

    def storeZerkineForAnnotation(self, annotation, moments, delayCommit=False):
        """Store ``moments`` as a space-separated string on ``annotation``'s row.

        :param delayCommit: if true, leave committing to the caller.
        :returns: True
        """
        serialised = ' '.join(str(moment) for moment in moments)
        cur = self.con.cursor()
        cur.execute(
            "UPDATE annotations SET zerkine_moment = :z WHERE id = :id",
            {'z': serialised, 'id': annotation.id}
        )
        if not delayCommit:
            self.con.commit()
        return True

    def getZerkines(self):
        """Return ``(id, zerkine_moment)`` rows for annotations that have moments."""
        cur = self.con.cursor()
        cur.execute("SELECT id, zerkine_moment FROM annotations WHERE zerkine_moment IS NOT NULL")
        return cur.fetchall()

    def getAllAnnotationPoints(self):
        """Return ``(annotations.id, points)`` rows for every segment of an
        annotation with positive area. Unlike :meth:`Annotation.fetchSegments`,
        this does not exclude ``'ount'``/``'iz'`` RLE placeholder rows."""
        cur = self.con.cursor()
        cur.execute("SELECT annotations.id, points FROM annotations INNER JOIN segments ON segments.annotation_id = annotations.id WHERE area > 0")
        return cur.fetchall()

    def getAnnotationById(self, annotation_id=None, withZerkine=False):
        """Return the annotation with ``annotation_id``; -1 or None picks a random one."""
        if annotation_id == -1:
            annotation_id = None
        return self.getRandomAnnotation(annotation_id=annotation_id, withZerkine=withZerkine)

    def getRandomAnnotation(self, annotation_id=None, category_id=None, withZerkine=False):
        """Return a single annotation from :meth:`getRandomAnnotations`, or None."""
        result = self.getRandomAnnotations(annotation_id, category_id, withZerkine, limit=1)
        return result[0] if result else None

    def getRandomAnnotations(self, annotation_id=None, category_id=None, withZerkine=False, limit=None):
        """Return annotations in random order.

        Filters by ``annotation_id`` if given, else by ``category_id`` if given.
        ``withZerkine`` restricts results to annotations with stored moments,
        and ``limit`` caps the number of results.
        """
        cur = self.con.cursor()
        params = []
        if annotation_id:
            where = "id = ?"
            params.append(annotation_id)
        elif category_id:
            where = "category_id = ?"
            params.append(category_id)
        else:
            where = "1=1"

        if withZerkine:
            where += " AND zerkine_moment IS NOT NULL"

        # int() ensures only a number is interpolated into the SQL text
        sqlLimit = f"LIMIT {int(limit)}" if limit else ""

        cur.execute(f"SELECT * FROM annotations WHERE {where} ORDER BY RANDOM() {sqlLimit}", tuple(params))
        return [Annotation(ann, self) for ann in cur]