coco/server.py

125 lines
4.2 KiB
Python

import tornado.ioloop
import tornado.web
import tornado.websocket
import argparse
import logging
import coloredlogs
from coco.storage import COCOStorage
import json
from urllib.parse import urlparse
# Module-level logger shared by the request handlers defined in this file.
logger = logging.getLogger('coco.server')
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that understands objects exposing a ``forJson()`` method.

    An object with a callable ``forJson`` attribute is encoded as whatever
    that method returns; every other unsupported object is handed to the
    stock encoder.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for *obj*.

        Raises:
            TypeError: raised by the base encoder when *obj* has no callable
                ``forJson`` attribute.
        """
        if not callable(getattr(obj, "forJson", None)):
            # Nothing we know how to convert: the base class raises TypeError.
            return super().default(obj)
        return obj.forJson()
class RestHandler(tornado.web.RequestHandler):
    """Base class for JSON GET endpoints backed by a ``COCOStorage``.

    Subclasses provide ``getData(*params)``; its return value is serialised
    with ``JsonEncoder`` and written as the response body.
    """

    def initialize(self, storage: COCOStorage):
        """Mark the response as JSON and keep the storage for ``getData``."""
        self.set_header("Content-Type", "application/json")
        self.storage = storage

    def get(self, *params):
        """Handle GET: encode the subclass's ``getData`` result and write it.

        ``params`` are the positional arguments Tornado passes to ``get``;
        they are forwarded unchanged to ``getData``.
        """
        payload = self.getData(*params)
        body = json.dumps(payload, cls=JsonEncoder)
        self.write(body)
class CategoryHandler(RestHandler):
    """Endpoint that returns the categories known to the storage."""

    def getData(self):
        """Return whatever ``storage.getCategories()`` yields."""
        categories = self.storage.getCategories()
        return categories
class AnnotationHandler(RestHandler):
    """Endpoint that returns one annotation from the storage.

    Optional query arguments, all integers:
        id: fetch this specific annotation.
        category: restrict the lookup to this category id.
        normalise: when non-zero, the annotation is returned via
            ``annotation.getNormalised(normalise, normalise)``.

    An argument that is missing or empty is treated as "not given". A value
    that is not a valid integer produces HTTP 400 instead of an unhandled
    ``ValueError`` (which Tornado would report as HTTP 500).
    """

    def _getIntArgument(self, name):
        """Return query argument *name* as an int, or None if absent/empty.

        Raises:
            tornado.web.HTTPError: 400 when the value is not an integer.
        """
        raw = self.get_argument(name, None)
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError:
            # Malformed client input is a client error, not a server fault.
            raise tornado.web.HTTPError(
                400,
                "Query argument %r must be an integer, got %r",
                name,
                raw,
            ) from None

    def getData(self):
        """Look up an annotation according to the query arguments.

        Returns:
            The object returned by ``storage.getRandomAnnotation``, or its
            ``getNormalised`` result when ``normalise`` is non-zero.

        Raises:
            tornado.web.HTTPError: 400 when a query argument is malformed.
        """
        annotation_id = self._getIntArgument('id')
        category_id = self._getIntArgument('category')

        normalise = self._getIntArgument('normalise')
        if normalise is None:
            # Keep the historical "False" marker for "not requested".
            normalise = False

        logger.debug(f'Get annotation id: {annotation_id}, category: {category_id}, normalised: {normalise}')

        # NOTE(review): assumes getRandomAnnotation returns an annotation
        # object; if it can return None the normalise path below would fail.
        annotation = self.storage.getRandomAnnotation(
            annotation_id=annotation_id,
            category_id=category_id,
        )
        if normalise:
            return annotation.getNormalised(normalise, normalise)
        return annotation
class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that greets clients and logs what they send.

    Connections are accepted only when the ``Origin`` header's hostname is
    listed in ``CORS_ORIGINS``.
    """

    CORS_ORIGINS = ['localhost', 'coco.local', 'r3.local']

    # Handlers of the currently open connections. This must exist at class
    # level: open() adds to it, and without it every connection attempt
    # failed with AttributeError.
    connections = set()

    def check_origin(self, origin):
        """Return True when *origin*'s hostname is in ``CORS_ORIGINS``."""
        parsed_origin = urlparse(origin)
        # Compare the hostname only; netloc would include the port
        # (e.g. "localhost:3333") and never match the list.
        return parsed_origin.hostname in self.CORS_ORIGINS

    def open(self, p=None):
        """Register a newly connected client and send it a greeting.

        ``p`` receives the optional positional argument Tornado passes from
        the route's capture group; it is not used.
        """
        WebSocketHandler.connections.add(self)
        logger.info("New client connected")
        self.write_message("hello!")

    def on_message(self, message):
        """Log a message received from the client."""
        logger.debug(f"recieve: {message}")

    def on_close(self):
        """Forget the client so closed handlers do not accumulate."""
        WebSocketHandler.connections.discard(self)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that adds a CORS header to HTML responses."""

    def set_extra_headers(self, path):
        """Add ``Access-Control-Allow-Origin: *`` when *path* ends in ``.html``.

        Other paths are served without any extra header.
        """
        if not path.endswith('.html'):
            return
        self.set_header("Access-Control-Allow-Origin", "*")
def make_app(db_filename, debug):
    """Build the Tornado application serving the API and the static site.

    Args:
        db_filename: passed to ``COCOStorage`` to open the database.
        debug: forwarded as the ``debug`` setting of ``tornado.web.Application``.

    Returns:
        The configured ``tornado.web.Application``.
    """
    storage = COCOStorage(db_filename)
    static_options = {"path": 'www', "default_filename": 'index.html'}

    # Order matters: the catch-all static route must stay last.
    routes = [
        (r"/ws(.*)", WebSocketHandler),
        (r"/categories.json", CategoryHandler, {'storage': storage}),
        (r"/annotation.json", AnnotationHandler, {'storage': storage}),
        (r"/(.*)", StaticFileWithHeaderHandler, static_options),
    ]
    return tornado.web.Application(routes, debug=debug)
if __name__ == "__main__":
    # Command-line entry point: parse options, set up logging, then serve.
    parser = argparse.ArgumentParser(description='Server for COCO web interface')
    parser.add_argument('--port', '-P', type=int, default=8888, help='Port to listen on')
    parser.add_argument('--db', type=str, metavar='DATABASE', required=True, help='Database to serve from')
    parser.add_argument('--verbose', '-v', action='store_true', help='Increase log level')
    options = parser.parse_args()

    # --verbose selects DEBUG logging and is also passed to make_app as debug.
    log_format = "%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s"
    if options.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    coloredlogs.install(level=level, fmt=log_format)

    server = make_app(options.db, debug=options.verbose)
    server.listen(options.port)
    # Blocks until the IOLoop is stopped.
    tornado.ioloop.IOLoop.current().start()