199 lines
No EOL
6.9 KiB
Python
199 lines
No EOL
6.9 KiB
Python
import tornado.ioloop
|
|
import tornado.web
|
|
import tornado.websocket
|
|
import argparse
|
|
import logging
|
|
import coloredlogs
|
|
from coco.storage import COCOStorage
|
|
import json
|
|
from urllib.parse import urlparse
|
|
import uuid
|
|
import os
|
|
import glob
|
|
|
|
# Module-wide logger shared by every handler below; level and format are
# configured through coloredlogs when this file is run as a script.
logger = logging.getLogger(name='coco.server')
|
|
|
|
class JsonEncoder(json.JSONEncoder):
    """JSON encoder aware of objects that expose a ``forJson()`` hook.

    An object with a callable ``forJson`` attribute is serialised as the value
    that method returns. Anything else is handed to the stock encoder, which
    raises ``TypeError`` for unsupported types.
    """

    def default(self, obj):
        hook = getattr(obj, "forJson", None)
        if not callable(hook):
            # Defer to the base implementation so it raises the TypeError.
            return super().default(obj)
        return obj.forJson()
|
|
|
|
class RestHandler(tornado.web.RequestHandler):
    """Base class for the JSON endpoints.

    Subclasses implement ``getData(*params)``. Its return value is serialised
    with ``JsonEncoder`` and written as the response body, identically for
    GET and POST requests.
    """

    def initialize(self, storage: COCOStorage):
        self.set_header("Content-Type", "application/json")
        self.storage = storage

    def _respond(self, *params):
        # Shared by GET and POST: compute the payload and write it as JSON.
        payload = self.getData(*params)
        self.write(json.dumps(payload, cls=JsonEncoder))

    def get(self, *params):
        self._respond(*params)

    def post(self, *params):
        self._respond(*params)
|
|
|
|
class CategoryHandler(RestHandler):
    """JSON endpoint returning every category known to the storage."""

    def getData(self):
        categories = self.storage.getCategories()
        return categories
|
|
|
|
class AnnotationHandler(RestHandler):
    """JSON endpoint returning one annotation, optionally normalised.

    Optional query arguments:

    * ``id`` -- integer id of a specific annotation to fetch;
    * ``category`` -- integer category id to draw the annotation from;
    * ``normalise`` -- non-zero integer ``n``: return
      ``annotation.getNormalised(n, n)`` instead of the raw annotation.

    ``id`` and ``category`` are both forwarded to
    ``storage.getRandomAnnotation``; how it combines them is defined there.
    """

    def _int_argument(self, name):
        """Return query argument *name* as an ``int``, or ``None`` if absent/empty.

        Raises:
            tornado.web.HTTPError: 400 when the value is not an integer, so
                bad client input is reported as such rather than as a 500.
        """
        raw = self.get_argument(name, None)
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError:
            raise tornado.web.HTTPError(
                400, 'query argument %r must be an integer, got %r', name, raw) from None

    def getData(self):
        """Fetch the requested annotation.

        Returns:
            The object from ``storage.getRandomAnnotation``, or the result of
            its ``getNormalised(n, n)`` when ``normalise=n`` is non-zero.

        Raises:
            tornado.web.HTTPError: 400 if ``id``, ``category`` or
                ``normalise`` is present but not an integer.
        """
        annotation_id = self._int_argument('id')
        category_id = self._int_argument('category')
        # Absent, empty or 0 all mean "return the annotation as stored".
        normalise = self._int_argument('normalise') or False

        logger.debug(f'Get annotation id: {annotation_id}, category: {category_id}, normalised: {normalise}')

        annotation = self.storage.getRandomAnnotation(annotation_id=annotation_id, category_id=category_id)
        if normalise:
            # Same target size on both axes.
            return annotation.getNormalised(normalise, normalise)
        return annotation
|
|
|
|
|
|
class SaveHandler(RestHandler):
    """JSON endpoint that renders a submitted composition to an SVG file.

    The request body is expected to be JSON shaped like
    ``{"scene": <int>, "annotations": [{"id": <int>, "x": <num>, "y": <num>}, ...]}``.
    Only ids and positions are taken from the client. All SVG markup is
    regenerated server-side from storage, and every interpolated string is
    XML-escaped, so the client cannot inject markup into the saved file.
    """

    # Maximum number of annotations rendered per submission; further entries
    # in the request are silently ignored.
    MAX_ANNOTATIONS = 100

    def getData(self):
        """Save an SVG, regenerated on the server to prevent malicious input.

        Returns:
            dict: ``{'submission': '/saved/<uuid>.svg'}``, the public URL of
            the file written under ``www/saved``.

        Raises:
            tornado.web.HTTPError: 400 if the body is malformed (invalid JSON,
                missing keys, non-numeric or non-finite values) or references
                an annotation that storage does not return.
        """
        scene, placements = self._parse_request()

        with open('www/canvas_patterns.json') as fp:
            patterns = json.load(fp)  # str(category id) -> fill value

        groups = []
        annotations = []
        for annId, x, y in placements:
            ann = self.storage.getAnnotationById(annId)
            if ann is None:
                # NOTE(review): assumes getAnnotationById returns None for unknown ids.
                raise tornado.web.HTTPError(400, 'unknown annotation id %d', annId)
            groups.append(self._render_group(ann, x, y, patterns[str(ann.category_id)]))
            annotations.append({'id': annId, 'x': x, 'y': y})

        # Machine-readable record of the composition, embedded in the SVG.
        source = json.dumps({'scene': scene, 'annotations': annotations})

        with open('www/canvas.svg') as fp:
            svgContent = fp.read()
        # NOTE(review): `source` is already JSON and is encoded a second time,
        # so the template receives it as a quoted JSON string -- confirm this
        # matches how www/canvas.svg consumes {source}.
        svgContent = (
            svgContent.replace('{source}', json.dumps(source))
            .replace('</svg>', "".join(groups) + "</svg>")
            .replace('{scenenr}', str(scene))
        )

        saveId = uuid.uuid4().hex + '.svg'
        with open(os.path.join('www/saved', saveId), 'w') as fp:
            fp.write(svgContent)
        return {'submission': '/saved/' + saveId}

    def _parse_request(self):
        """Decode and validate the request body.

        Returns ``(scene, [(annotation_id, x, y), ...])`` with at most
        ``MAX_ANNOTATIONS`` placements. Raises ``HTTPError(400)`` on bad input.
        """
        import math  # only needed for the finiteness check below

        try:
            req = tornado.escape.json_decode(self.request.body)
            scene = int(req['scene'])
            placements = [
                (int(item['id']), float(item['x']), float(item['y']))
                for item in req['annotations'][:self.MAX_ANNOTATIONS]
            ]
        except (KeyError, TypeError, ValueError, OverflowError) as err:
            raise tornado.web.HTTPError(400, 'malformed save request: %s', err) from err

        # float() accepts "nan"/"inf"; those would yield invalid SVG
        # coordinates and non-standard JSON (NaN/Infinity) in `source`.
        if not all(math.isfinite(x) and math.isfinite(y) for _, x, y in placements):
            raise tornado.web.HTTPError(400, 'annotation coordinates must be finite')
        return scene, placements

    def _render_group(self, ann, x, y, fill):
        """Return the SVG ``<g>`` element for *ann*, translated to ``(x, y)``.

        The annotation is normalised with ``getNormalised(100, 100)``. Its
        source image is drawn scaled by the normalised scale and offset by the
        annotation's ``bbox[0]``/``bbox[1]``. Its segments become ``<path>``
        elements filled with *fill*, and the category name becomes a label.
        """
        def esc(value):
            # Escape &, <, >, and quotes so values cannot break out of
            # attributes or text content.
            return tornado.escape.xhtml_escape(str(value))

        normalisedAnn = ann.getNormalised(100, 100)
        scale = normalisedAnn.scale
        cat = self.storage.getCategory(ann.category_id)
        image = self.storage.getImage(ann.image_id)

        paths = "".join(
            f'<path fill="{esc(fill)}" d="{esc(segment.getD())}"></path>'
            for segment in normalisedAnn.segments
        )
        # Label is placed 5 units past normalisedAnn.bbox[2], at bbox[3].
        textX = normalisedAnn.bbox[2] + 5
        textY = normalisedAnn.bbox[3]

        return f"""
<g data-id="{esc(ann.id)}" transform="translate({x},{y})">
    <image href="{esc(image['coco_url'])}"
        width="{image['width'] * scale}"
        height="{image['height'] * scale}"
        x="{ann.bbox[0] * -1 * scale}"
        y="{ann.bbox[1] * -1 * scale}"></image>
    {paths}
    <text fill="white" font-size="30pt" font-family="sans-serif" x="{textX}" y="{textY}">{esc(cat['name'])}</text>
</g>
"""
|
|
|
|
|
|
|
|
class SavedHandler(tornado.web.RequestHandler):
    """HTML gallery page showing the most recently saved SVGs."""

    def initialize(self, storage: COCOStorage):
        # Accepted for parity with the other handlers; get() does not use it.
        self.storage = storage

    def get(self):
        """Inline up to 100 of the newest ``www/saved/*.svg`` files into ``www/saved.html``."""
        newest_first = sorted(glob.glob("www/saved/*.svg"), key=os.path.getmtime, reverse=True)

        fragments = []
        for path in newest_first[:100]:
            with open(path, 'r') as handle:
                # Drop the first line (presumably the XML declaration) so the
                # SVG can be embedded in the HTML page.
                fragments.append(handle.read().partition('\n')[2])

        with open("www/saved.html") as handle:
            page = handle.read()

        self.write(page.replace("{images}", ''.join(fragments)))
|
|
|
|
|
|
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that allows cross-origin access to ``.html`` files."""

    def set_extra_headers(self, path):
        """Add ``Access-Control-Allow-Origin: *`` when *path* ends in ``.html``."""
        if not path.endswith('.html'):
            return
        self.set_header("Access-Control-Allow-Origin", "*")
|
|
|
|
def make_app(db_filename, debug):
    """Build the tornado Application for the COCO web interface.

    Every REST handler receives the same ``COCOStorage`` opened on
    *db_filename*. Any path not matched by them is served from ``www``, with
    ``index.html`` as the default file. *debug* is passed straight to
    ``tornado.web.Application``.
    """
    storage = COCOStorage(db_filename)

    rest_routes = [
        (r"/categories.json", CategoryHandler),
        (r"/annotation.json", AnnotationHandler),
        (r"/save", SaveHandler),
        (r"/saved", SavedHandler),
    ]
    routes = [(pattern, handler, {'storage': storage}) for pattern, handler in rest_routes]
    # Routes are tried in order, so the catch-all static route goes last.
    routes.append(
        (r"/(.*)", StaticFileWithHeaderHandler, {"path": 'www', "default_filename": 'index.html'})
    )

    return tornado.web.Application(routes, debug=debug)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, serve forever.
    parser = argparse.ArgumentParser(description='Server for COCO web interface')
    parser.add_argument('--port', '-P', type=int, default=8888, help='Port to listen on')
    parser.add_argument('--db', type=str, metavar='DATABASE', required=True,
                        help='Database to serve from')
    parser.add_argument('--verbose', '-v', action='store_true', help='Increase log level')
    options = parser.parse_args()

    coloredlogs.install(
        level=logging.DEBUG if options.verbose else logging.INFO,
        fmt="%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s",
    )

    # --verbose also switches the tornado application into debug mode.
    application = make_app(options.db, debug=options.verbose)
    application.listen(options.port)
    tornado.ioloop.IOLoop.current().start()