"""Tornado server exposing a COCOStorage database over JSON REST endpoints and a WebSocket."""
import argparse
import json
import logging
from urllib.parse import urlparse

import coloredlogs
import tornado.ioloop
import tornado.web
import tornado.websocket

from coco.storage import COCOStorage

logger = logging.getLogger('coco.server')


class JsonEncoder(json.JSONEncoder):
    """JSON encoder that serialises any object exposing a callable ``forJson()`` method."""

    def default(self, obj):
        method = getattr(obj, "forJson", None)
        if callable(method):
            return method()
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, obj)


class RestHandler(tornado.web.RequestHandler):
    """Base GET handler that writes the result of ``getData(*params)`` as JSON.

    Subclasses implement ``getData``; ``params`` are the URL pattern's capture groups.
    """

    def initialize(self, storage: COCOStorage):
        self.storage = storage
        self.set_header("Content-Type", "application/json")

    def get(self, *params):
        self.write(json.dumps(self.getData(*params), cls=JsonEncoder))

    def _get_int_argument(self, name, default=None):
        """Return query argument *name* as an int.

        Returns *default* when the argument is missing or empty. Raises
        ``tornado.web.HTTPError(400)`` when it is present but not an integer,
        so bad client input yields a 400 instead of a 500.
        """
        raw = self.get_argument(name, None)
        if not raw:
            return default
        try:
            return int(raw)
        except ValueError:
            raise tornado.web.HTTPError(400, f"Query argument '{name}' must be an integer")


class CategoryHandler(RestHandler):
    """Serves the list of categories from storage."""

    def getData(self):
        return self.storage.getCategories()


class AnnotationHandler(RestHandler):
    """Serves one annotation, optionally filtered by id/category and normalised.

    Query arguments (all optional integers):
      id        -- fetch a specific annotation
      category  -- restrict to a category id
      normalise -- if non-zero, return ``annotation.getNormalised(normalise, normalise)``
    """

    def getData(self):
        # get specific annotation
        annotation_id = self._get_int_argument('id')
        # get by category id
        category_id = self._get_int_argument('category')
        # False (not None) keeps the original "not requested" sentinel in the log line
        normalise = self._get_int_argument('normalise', False)

        logger.debug(f'Get annotation id: {annotation_id}, category: {category_id}, normalised: {normalise}')
        annotation = self.storage.getRandomAnnotation(annotation_id=annotation_id, category_id=category_id)
        if normalise:
            return annotation.getNormalised(normalise, normalise)
        return annotation


class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that greets clients and logs incoming messages."""

    CORS_ORIGINS = ['localhost', 'coco.local', 'r3.local']
    # Currently open client connections, shared across all handler instances.
    connections = set()

    def check_origin(self, origin):
        parsed_origin = urlparse(origin)
        # parsed_origin.netloc.lower() gives localhost:3333; .hostname drops the port
        return parsed_origin.hostname in self.CORS_ORIGINS

    # the client connected; `p` receives the /ws(.*) capture group
    def open(self, p=None):
        WebSocketHandler.connections.add(self)
        logger.info("New client connected")
        self.write_message("hello!")

    # the client sent the message
    def on_message(self, message):
        logger.debug(f"recieve: {message}")

    # the client disconnected: forget it so the set does not grow forever
    def on_close(self):
        WebSocketHandler.connections.discard(self)
        logger.info("Client disconnected")


class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that allows cross-origin access to HTML files."""

    def set_extra_headers(self, path):
        """Add ``Access-Control-Allow-Origin: *`` to responses for ``.html`` files."""
        if path.endswith('.html'):
            self.set_header("Access-Control-Allow-Origin", "*")


def make_app(db_filename, debug):
    """Build the Tornado application backed by a COCOStorage opened on *db_filename*."""
    storage = COCOStorage(db_filename)
    return tornado.web.Application([
        (r"/ws(.*)", WebSocketHandler),
        (r"/categories.json", CategoryHandler, {'storage': storage}),
        (r"/annotation.json", AnnotationHandler, {'storage': storage}),
        (r"/(.*)", StaticFileWithHeaderHandler, {"path": 'www', "default_filename": 'index.html'}),
    ], debug=debug)


if __name__ == "__main__":
    argParser = argparse.ArgumentParser(description='Server for COCO web interface')
    argParser.add_argument(
        '--port',
        '-P',
        type=int,
        default=8888,
        help='Port to listen on'
    )
    argParser.add_argument(
        '--db',
        type=str,
        metavar='DATABASE',
        required=True,
        help='Database to serve from'
    )
    argParser.add_argument(
        '--verbose',
        '-v',
        action='store_true',
        help='Increase log level'
    )

    args = argParser.parse_args()

    loglevel = logging.DEBUG if args.verbose else logging.INFO

    coloredlogs.install(
        level=loglevel,
        fmt="%(asctime)s %(hostname)s %(name)s[%(process)d] %(levelname)s %(message)s"
    )

    app = make_app(args.db, debug=args.verbose)
    app.listen(args.port)
    tornado.ioloop.IOLoop.current().start()