125 lines
4.2 KiB
Python
125 lines
4.2 KiB
Python
|
import tornado.ioloop
|
||
|
import tornado.web
|
||
|
import tornado.websocket
|
||
|
import argparse
|
||
|
import logging
|
||
|
import coloredlogs
|
||
|
from coco.storage import COCOStorage
|
||
|
import json
|
||
|
from urllib.parse import urlparse
|
||
|
|
||
|
# Shared logger for this server module (named explicitly rather than via __name__).
logger = logging.getLogger(name="coco.server")
|
||
|
|
||
|
class JsonEncoder(json.JSONEncoder):
    """JSON encoder for objects that know how to serialise themselves.

    Any object exposing a callable ``forJson`` attribute is encoded as the
    value that method returns. Everything else is handed to the stock
    ``json.JSONEncoder.default``, which raises ``TypeError`` for unsupported
    types.
    """

    def default(self, obj):
        to_json = getattr(obj, "forJson", None)
        if not callable(to_json):
            # No custom hook: defer to the base class so it raises TypeError.
            return super().default(obj)
        return to_json()
|
||
|
|
||
|
class RestHandler(tornado.web.RequestHandler):
    """Base class for the read-only JSON endpoints.

    Subclasses provide ``getData(*params)``. ``get`` forwards any URL capture
    groups to it, serialises the result with ``JsonEncoder`` and writes it as
    the response body.
    """

    def initialize(self, storage: COCOStorage):
        # Tornado calls initialize() with the kwargs dict given in the route table.
        self.set_header("Content-Type", "application/json")
        self.storage = storage

    def get(self, *params):
        payload = self.getData(*params)
        body = json.dumps(payload, cls=JsonEncoder)
        self.write(body)
|
||
|
|
||
|
class CategoryHandler(RestHandler):
    """Serve whatever ``storage.getCategories()`` returns, encoded as JSON."""

    def getData(self):
        categories = self.storage.getCategories()
        return categories
|
||
|
|
||
|
class AnnotationHandler(RestHandler):
    """Serve a single annotation from storage as JSON.

    Optional query parameters:
      * ``id``        -- integer id of a specific annotation.
      * ``category``  -- integer category id to restrict the pick to.
      * ``normalise`` -- integer; when non-zero the annotation is returned via
        ``annotation.getNormalised(normalise, normalise)``.

    A parameter that is present but not a valid integer produces HTTP 400
    instead of an unhandled ``ValueError`` (HTTP 500).
    """

    def _get_int_argument(self, name, default=None):
        """Return query argument *name* as an int, or *default* if absent or empty.

        Raises:
            tornado.web.HTTPError: 400 if the value is present but not an integer.
        """
        raw = self.get_argument(name, None)
        if raw is None or raw == "":
            return default
        try:
            return int(raw)
        except ValueError:
            # The log_message is %-formatted with the trailing args by Tornado.
            raise tornado.web.HTTPError(
                400, "query argument %r must be an integer, got %r", name, raw
            ) from None

    def getData(self):
        # Specific annotation, if requested.
        annotation_id = self._get_int_argument('id')
        # Restrict to one category, if requested.
        category_id = self._get_int_argument('category')
        # False when absent; 0 is treated the same as absent (no normalisation).
        normalise = self._get_int_argument('normalise', False)

        logger.debug(f'Get annotation id: {annotation_id}, category: {category_id}, normalised: {normalise}')

        annotation = self.storage.getRandomAnnotation(annotation_id=annotation_id, category_id=category_id)
        # NOTE(review): if getRandomAnnotation can return None (no match), the
        # normalise branch below would raise AttributeError -- confirm.
        if normalise:
            # Same value for both arguments -- presumably a square target size;
            # verify against getNormalised's signature.
            return annotation.getNormalised(normalise, normalise)
        return annotation
|
||
|
|
||
|
|
||
|
class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that greets clients and logs their messages.

    Open connections are tracked in the class-level ``connections`` set. A
    handler is added in ``open`` and removed again in ``on_close``.
    """

    # Hostnames (port ignored) whose pages may open a socket to this server.
    CORS_ORIGINS = ['localhost', 'coco.local', 'r3.local']

    # Currently open handlers, shared by all instances. open() relies on this
    # attribute existing; without it every connection raised AttributeError.
    connections = set()

    def check_origin(self, origin):
        """Accept the handshake only if *origin*'s hostname is in CORS_ORIGINS.

        Only the hostname is compared (``netloc`` would include the port, e.g.
        ``localhost:3333``), so any port on an allowed host is accepted.
        """
        return urlparse(origin).hostname in self.CORS_ORIGINS

    def open(self, p=None):
        """Register the newly connected client and send it a greeting.

        *p* receives the ``(.*)`` capture of the ``/ws(.*)`` route; it is unused.
        """
        WebSocketHandler.connections.add(self)
        logger.info("New client connected")
        self.write_message("hello!")

    def on_message(self, message):
        """Log every message the client sends; nothing is sent back."""
        logger.debug(f"recieve: {message}")

    def on_close(self):
        """Forget the closed client so ``connections`` does not grow without bound."""
        WebSocketHandler.connections.discard(self)
        logger.info("Client disconnected")
|
||
|
|
||
|
|
||
|
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that allows any origin to fetch ``.html`` files."""

    def set_extra_headers(self, path):
        """Tornado hook for adding per-response headers.

        Adds ``Access-Control-Allow-Origin: *`` when *path* ends with ``.html``;
        other files get no extra headers.
        """
        if not path.endswith('.html'):
            return
        self.set_header("Access-Control-Allow-Origin", "*")
|
||
|
|
||
|
def make_app(db_filename, debug, static_path='www'):
    """Build the Tornado application for the COCO web interface.

    Args:
        db_filename: database file handed to ``COCOStorage``.
        debug: forwarded as ``tornado.web.Application(debug=...)``.
        static_path: directory served by the catch-all static route. Defaults
            to ``'www'``, which is resolved relative to the process's working
            directory.

    Returns:
        The configured ``tornado.web.Application`` (not yet listening).
    """
    storage = COCOStorage(db_filename)

    return tornado.web.Application([
        (r"/ws(.*)", WebSocketHandler),
        # Escape the dots: an unescaped "." matches any character, so e.g.
        # "/categoriesXjson" would otherwise hit the JSON handler.
        (r"/categories\.json", CategoryHandler, {'storage': storage}),
        (r"/annotation\.json", AnnotationHandler, {'storage': storage}),
        # Catch-all: anything not matched above is a file under static_path.
        (r"/(.*)", StaticFileWithHeaderHandler,
         {"path": static_path, "default_filename": 'index.html'}),
    ], debug=debug)
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line interface.
    parser = argparse.ArgumentParser(description='Server for COCO web interface')
    parser.add_argument('--port', '-P', type=int, default=8888,
                        help='Port to listen on')
    parser.add_argument('--db', type=str, metavar='DATABASE', required=True,
                        help='Database to serve from')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Increase log level')
    cli = parser.parse_args()

    # Logging: DEBUG with --verbose, INFO otherwise.
    if cli.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    coloredlogs.install(
        level=level,
        fmt="%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s",
    )

    # NOTE(review): --verbose also puts Tornado into debug mode (autoreload and
    # tracebacks in error responses); consider a separate flag for that.
    application = make_app(cli.db, debug=cli.verbose)
    application.listen(cli.port)
    tornado.ioloop.IOLoop.current().start()
|