import asyncio
import json
import logging
import os
import uuid
from queue import Empty, Queue
from threading import Event, Thread
from urllib.parse import urlparse

import geoip2.database
import httpagentparser
import tornado.ioloop
import tornado.web
import tornado.websocket

from sorteerhoed import HITStore
from sorteerhoed.Signal import Signal

logger = logging.getLogger("sorteerhoed").getChild("webserver")


def _origin_allowed(origin, allowed_hosts):
    """Return True when the host of the *origin* URL is in *allowed_hosts*.

    An entry starting with a dot (e.g. ``'.mturk.com'``) matches any subdomain
    of that domain. Any other entry (e.g. ``'localhost'``) matches that exact
    host or any subdomain of it. Matching is done on whole DNS labels, so
    ``'evillocalhost'`` is NOT accepted for ``'localhost'``.

    Origins without a usable host (e.g. the literal ``"null"`` origin, or a
    malformed URL) are rejected instead of raising.
    """
    try:
        # .hostname is lower-cased and has the port stripped
        # ("http://localhost:3333" -> "localhost").
        hostname = urlparse(origin).hostname
    except ValueError:
        # e.g. an unbalanced IPv6 bracket in an attacker-supplied header
        return False
    if not hostname:
        return False
    for allowed in allowed_hosts:
        if allowed.startswith('.'):
            if hostname.endswith(allowed):
                return True
        elif hostname == allowed or hostname.endswith('.' + allowed):
            return True
    return False


class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler adding a CORS header for .html and an SVG content type."""

    def set_extra_headers(self, path):
        """For subclass to add extra headers to the response"""
        if path.endswith('.html'):
            self.set_header("Access-Control-Allow-Origin", "*")
        if path.endswith('.svg'):
            self.set_header("Content-Type", "image/svg+xml")


class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint for a drawing client working on one HIT.

    The client connects with ``?id=<hit id>``. 'move' points are forwarded to
    the plotter queue; 'info' and connection details are reported to the event
    queue as ``hit.info`` signals; 'submit' finalises the drawing.
    """

    CORS_ORIGINS = ['localhost', '.mturk.com', 'here.rubenvandeven.com']
    connections = set()

    def initialize(self, config, plotterQ: Queue, eventQ: Queue, store: HITStore,
                   geoip_reader: geoip2.database.Reader):
        self.config = config
        self.plotterQ = plotterQ  # goes straight to the plotter, bypassing the main thread
        self.eventQ = eventQ
        self.store = store
        self.geoip_reader = geoip_reader

    def check_origin(self, origin):
        """Accept only origins whose host is in CORS_ORIGINS."""
        return _origin_allowed(origin, self.CORS_ORIGINS)

    # the client connected
    def open(self, p=None):
        """Register the client, load its HIT and report connection details.

        ``p`` receives the ``(.*)`` capture group of the ``/ws(.*)`` route and
        is unused.

        Raises:
            Exception: when the HIT has already been submitted.
        """
        self.__class__.connections.add(self)

        hit_id = self.get_query_argument('id')
        self.hit = self.store.getHitById(hit_id)
        if self.hit.submit_hit_at:
            raise Exception("Opening websocket for already submitted hit")

        logger.info(f"New client connected: {self.request.remote_ip} for {self.hit.id}/{self.hit.hit_id}")
        self.eventQ.put(Signal('hit.info', dict(hit_id=self.hit.id, ip=self.request.remote_ip)))
        self.strokes = []

        ua = self.request.headers.get('User-Agent', None)
        if ua:
            ua_info = httpagentparser.detect(ua)
            # Use .get(): parts missing from the parser's result must not
            # abort the connection with a KeyError.
            self.eventQ.put(Signal('hit.info', dict(
                hit_id=self.hit.id,
                os=ua_info.get('os', {}).get('name'),
                browser=ua_info.get('browser', {}).get('name'),
            )))

        try:
            geoip = self.geoip_reader.country(self.request.remote_ip)
            logger.info(f"Geo {geoip}")
            self.eventQ.put(Signal('hit.info', dict(hit_id=self.hit.id, location=geoip.country.name)))
        except Exception as e:
            # best-effort: lookup failures only degrade the status info
            logger.exception(e)
            logger.info("No geo IP possible")
            self.eventQ.put(Signal('hit.info', dict(hit_id=self.hit.id, location='Unknown')))

    # the client sent the message
    def on_message(self, message):
        """Dispatch a JSON message from the drawing client on its 'action' key.

        Any error while handling a message is logged and the message dropped,
        so a single bad message does not close the connection.
        """
        logger.debug(f"recieve: {message}")
        try:
            msg = json.loads(message)
            # TODO: sanitize input: min/max, limit strokes
            action = msg['action']
            if action == 'move':
                # TODO: min/max input
                point = [float(msg['direction'][0]), float(msg['direction'][1]), bool(msg['mouse'])]
                self.strokes.append(point)
                self.plotterQ.put(point)
            elif action == 'up':
                logger.info(f'up: {msg}')
                # Coerce like 'move' so a malformed coordinate is rejected here
                # instead of being stored and breaking strokes2D() on submit.
                point = [float(msg['direction'][0]), float(msg['direction'][1]), 1]
                self.strokes.append(point)
            elif action == 'submit':
                logger.info(f'submit: {msg}')
                submitted_uuid = self.submit_strokes()
                if not submitted_uuid:
                    self.write_message(json.dumps('error'))
                    return

                self.write_message(json.dumps({
                    'action': 'submitted',
                    'msg': f"Submission ok, please refer to your submission as: {self.hit.uuid}"
                }))
            elif action == 'down':
                # not used, implicit in move?
                pass
            elif action == 'info':
                self.eventQ.put(Signal('hit.info', dict(
                    hit_id=self.hit.id,
                    resolution=msg['resolution'],
                    browser=msg['browser']
                )))
            else:
                # self.send({'alert': 'Unknown request: {}'.format(message)})
                logger.warning('Unknown request: {}'.format(message))
        except Exception as e:
            # self.send({'alert': 'Invalid request: {}'.format(e)})
            logger.exception(e)

    # client disconnected
    def on_close(self):
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    def submit_strokes(self):
        """Announce the submission and, in dummy-plotter mode, write the drawing.

        Returns:
            The HIT's uuid, or False when no strokes were recorded.
        """
        if not self.strokes:
            return False

        self.eventQ.put(Signal("server.submit", dict(hit_id=self.hit.id)))

        if self.config['dummy_plotter']:
            d = strokes2D(self.strokes)
            # NOTE(review): the SVG template appears empty in this copy of the
            # file and `d` is otherwise unused -- confirm the markup (presumably
            # an <svg> containing <path d="{d}">) has not been lost.
            svg = f""" """

            filename = self.hit.getImagePath()
            logger.info(f"Write to {filename}")
            with open(filename, 'w') as fp:
                fp.write(svg)

            # we fake a hit.scanned event
            self.eventQ.put(Signal('hit.scanned', {'hit_id': self.hit.id}))

        return self.hit.uuid

    @classmethod
    def rmConnection(cls, client):
        """Forget *client*; no-op if it is not registered."""
        cls.connections.discard(client)


class StatusWebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that pushes StatusPage property changes to viewers."""

    CORS_ORIGINS = ['localhost']
    connections = set()

    def initialize(self):
        pass

    def check_origin(self, origin):
        """Accept only origins whose host is in CORS_ORIGINS."""
        return _origin_allowed(origin, self.CORS_ORIGINS)

    # the client connected
    def open(self, p=None):
        self.__class__.connections.add(self)

    # client disconnected
    def on_close(self):
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        """Forget *client*; no-op if it is not registered."""
        cls.connections.discard(client)

    @classmethod
    def update_for_all(cls, prop, value):
        """Send ``{'property': prop, 'value': value}`` as JSON to every status client.

        NOTE(review): Tornado objects are not thread-safe; if StatusPage is
        modified from a thread other than the IOLoop's, this should be
        scheduled with IOLoop.add_callback -- confirm the callers.
        """
        # Snapshot: on_close() may remove connections while we iterate.
        connections = list(cls.connections)
        if not connections:
            return
        payload = json.dumps({
            'property': prop,
            'value': value
        })
        for connection in connections:
            try:
                connection.write_message(payload)
            except tornado.websocket.WebSocketClosedError:
                # a closing client must not block updates to the others
                logger.debug("Skipping status update for closed connection")


def strokes2D(strokes):
    """Convert ``[x, y, flag]`` points into an SVG path ``d`` attribute string.

    The first point becomes an absolute moveto (``M``). Every later point is
    written relative to the previous one: as a relative moveto (``m``) when the
    previous point's flag equals 1, otherwise as a relative lineto (``l``). The
    ``l`` command letter is only emitted when switching into lineto mode.
    """
    parts = []
    last_stroke = None
    cmd = ""
    for stroke in strokes:
        if not last_stroke:
            parts.append(f"M{stroke[0]},{stroke[1]} ")
            cmd = 'M'
        else:
            if last_stroke[2] == 1:
                parts.append(" m")
                cmd = 'm'
            elif cmd != 'l':
                parts.append(' l ')
                cmd = 'l'
            rel_stroke = [stroke[0] - last_stroke[0], stroke[1] - last_stroke[1]]
            parts.append(f"{rel_stroke[0]},{rel_stroke[1]} ")
        last_stroke = stroke
    return "".join(parts)


class DrawPageHandler(tornado.web.RequestHandler):
    """Serve index.html with {IMAGE_URL} set to the last submitted drawing."""

    def initialize(self, store: HITStore, path: str):
        self.store = store
        self.path = path

    def get(self):
        try:
            hit_id = self.get_query_argument('id')
            hit = self.store.getHitById(hit_id)
        except Exception:
            self.write("HIT not found")
            return

        if hit.submit_page_at:
            self.write("HIT already submitted")
            return

        previous_hit = self.store.getLastSubmittedHit()
        if not previous_hit:
            # start with basic svg
            logger.warning("No previous HIT, start from basic svg")
            image = "/basic.svg"
        else:
            image = previous_hit.getImageUrl()
        logger.info(f"Image url: {image}")

        self.set_header("Access-Control-Allow-Origin", "*")
        # `with` closes the template file instead of leaking the handle
        with open(os.path.join(self.path, 'index.html'), 'r') as fp:
            contents = fp.read().replace("{IMAGE_URL}", image)
        self.write(contents)


class StatusPage():
    """
    Properties for on the status page, which are send over websockets the moment they are altered.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every status property to None (each assignment is broadcast)."""
        self.hit_id = None
        self.worker_id = None
        self.ip = None
        self.location = None
        self.browser = None
        self.os = None
        self.resolution = None
        self.state = None
        self.fee = None
        self.hit_created = None
        self.hit_opened = None

    def __setattr__(self, name, value):
        # store, then push the change to all connected status clients
        self.__dict__[name] = value
        StatusWebSocketHandler.update_for_all(name, value)

    def set(self, name, value):
        return self.__setattr__(name, value)


class Server:
    """
    Server for HIT -> plotter events
    As well as for the Status interface

    TODO: change to have the HIT_id as param to the page. Load hit from storage with previous image
    """

    def __init__(self, config, eventQ: Queue, runningEvent: Event, plotterQ: Queue, store: HITStore):
        self.isRunning = runningEvent
        self.eventQ = eventQ
        self.config = config
        self.logger = logger

        self.plotterQ = plotterQ  # communicate directly to plotter (skip main thread)

        self.web_root = os.path.join('www')

        self.server_loop = None
        self.store = store
        self.statusPage = StatusPage()

    def start(self):
        """Create a fresh asyncio loop, register the routes and run until stopped.

        The loop only starts when ``isRunning`` is set; on exit ``isRunning`` is
        cleared and the GeoIP database reader is closed.

        Raises:
            Exception: when GeoLite2-Country.mmdb is missing.
        """
        if not os.path.exists('GeoLite2-Country.mmdb'):
            raise Exception("Please download the GeoLite2 Country database and place the 'GeoLite2-Country.mmdb' file in the project root.")

        self.geoip_reader = geoip2.database.Reader('GeoLite2-Country.mmdb')

        try:
            asyncio.set_event_loop(asyncio.new_event_loop())
            application = tornado.web.Application([
                (r"/ws(.*)", WebSocketHandler, {
                    'config': self.config,
                    'plotterQ': self.plotterQ,
                    'eventQ': self.eventQ,
                    'store': self.store,
                    'geoip_reader': self.geoip_reader
                }),
                (r"/status/ws", StatusWebSocketHandler),
                (r"/draw", DrawPageHandler, dict(store=self.store, path=self.web_root)),
                (r"/(.*)", StaticFileWithHeaderHandler, {"path": self.web_root}),
            ], debug=True, autoreload=False)
            application.listen(self.config['server']['port'])
            self.server_loop = tornado.ioloop.IOLoop.current()
            if self.isRunning.is_set():
                self.server_loop.start()
        finally:
            self.logger.info("Stopping webserver")
            self.isRunning.clear()
            # release the memory-mapped GeoIP database
            self.geoip_reader.close()

    def stop(self):
        """Ask the server loop to stop; safe to call from another thread."""
        if self.server_loop:
            self.logger.debug("Got call to stop")
            self.server_loop.asyncio_loop.call_soon_threadsafe(self._stop)

    def _stop(self):
        self.server_loop.stop()