guest_worker/sorteerhoed/webserver.py
2020-01-22 19:07:07 +01:00

542 lines
20 KiB
Python

import json
import logging
import os
import tornado.ioloop
import tornado.web
import tornado.websocket
from urllib.parse import urlparse
import magic
from threading import Thread, Event
from queue import Queue, Empty
import asyncio
from sorteerhoed import HITStore
from sorteerhoed.Signal import Signal
import httpagentparser
import geoip2.database
import queue
import datetime
import html
logger = logging.getLogger("sorteerhoed").getChild("webserver")
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``datetime.datetime`` values as ISO strings.

    Every other type the standard encoder cannot handle is delegated to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, o):
        """Return a JSON-serialisable stand-in for ``o``.

        :param o: the object the standard encoder could not serialise
        :returns: ``o.isoformat(timespec='seconds')`` for datetimes
        :raises TypeError: for any other unsupported type
        """
        if isinstance(o, datetime.datetime):
            return o.isoformat(timespec='seconds')
        # super() is already bound to this instance; passing ``self`` again
        # would raise an argument-count TypeError instead of the regular
        # "is not JSON serializable" error.
        return super().default(o)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that adds response headers based on the file extension."""

    def set_extra_headers(self, path):
        """For subclass to add extra headers to the response

        * ``.html``: allow cross-origin access from any origin.
        * ``.svg``: force the SVG content type.
        * ``.png``: sniff the file on disk and serve it as SVG when it
          actually contains SVG data.

        :param path: requested path, relative to the handler's ``root``
        """
        if path.endswith('.html'):
            self.set_header("Access-Control-Allow-Origin", "*")
        if path.endswith('.svg'):
            self.set_header("Content-Type", "image/svg+xml")
        if path.endswith('.png'):
            # in testing, without scanner, images are saved as svg
            mime = magic.from_file(os.path.join(self.root, path), mime=True)
            # Log through the module logger instead of printing to stdout on
            # every PNG request.
            logger.debug(f"Detected mime type {mime} for {path}")
            if mime == 'image/svg+xml':
                self.set_header("Content-Type", "image/svg+xml")
class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """
    Websocket from the workers

    Accepts drawing messages for the last assignment of the current HIT,
    forwards 'move' points to the plotter queue and reports lifecycle events
    ('server.open', 'assignment.info', 'assignment.submit', 'server.close')
    on the event queue.
    """
    # Entries starting with a dot match any sub-domain; the others match the
    # host itself or any of its sub-domains.
    CORS_ORIGINS = ['localhost', '.mturk.com', 'here.rubenvandeven.com', 'guest.rubenvandeven.com']
    # All currently registered handler instances (shared on the class).
    connections = set()

    def initialize(self, config, plotterQ: Queue, eventQ: Queue, store: HITStore):
        """Store the shared objects handed over by the application routing table.

        :param config: configuration mapping; ``config['scanner']`` is read on submit
        :param plotterQ: queue that receives every 'move' point
        :param eventQ: queue that receives ``Signal`` objects
        :param store: HIT store providing ``currentHit`` and the HIT timeout
        """
        self.config = config
        self.plotterQ = plotterQ
        self.eventQ = eventQ
        self.store = store
        self.assignment_id = None
        self.abandoned = False
        # Defaults for everything open() fills in, so later callbacks never
        # trip over a missing attribute when open() bailed out early.
        self.hit = None
        self.assignment = None
        self.timeout = None
        self.strokes = []
        self.validated = False  # set once open() ran to completion

    def check_origin(self, origin):
        """Return True when the Origin header's host is covered by ``CORS_ORIGINS``.

        An Origin without a hostname (for example ``null``) is rejected rather
        than raising. Matching happens on a domain boundary, so a host that
        merely ends in the same characters (``evillocalhost``) is refused.
        """
        # urlparse(...).hostname is lower-cased and excludes the port;
        # parsed_origin.netloc.lower() would give localhost:3333
        hostname = urlparse(origin).hostname
        if not hostname:
            return False
        for allowed in self.CORS_ORIGINS:
            if allowed.startswith('.'):
                if hostname.endswith(allowed):
                    return True
            elif hostname == allowed or hostname.endswith('.' + allowed):
                return True
        return False

    # the client connected
    def open(self, p=None):
        """Validate the connecting client against the current HIT and assignment.

        The socket is closed when the ``id`` query argument is not the current
        HIT. An exception is raised when the assignment is not the HIT's last
        assignment or when the HIT is already submitted.

        :param p: remainder of the URL captured by the route pattern (unused)
        """
        self.__class__.connections.add(self)
        hit_id = int(self.get_query_argument('id'))
        if hit_id != self.store.currentHit.id:
            self.close()
            return

        self.hit = self.store.currentHit

        # my core assumption about assignment_id was wrong. It is not unique per worker, so we need to merge those
        self.assignment_id = str(self.get_query_argument('assignmentId'))
        self.assignment_id += '_' + str(self.get_query_argument('workerId'))

        self.assignment = self.hit.getLastAssignment()
        if self.assignment.assignment_id != self.assignment_id:
            raise Exception(f"Opening websocket for invalid assignment {self.assignment_id}")

        # Moment after which incoming messages are refused (see on_message).
        self.timeout = self.assignment.created_at + datetime.timedelta(seconds=self.store.getHitTimeout())

        if self.hit.isSubmitted():
            raise Exception("Opening websocket for already submitted hit")

        self.eventQ.put(Signal('server.open', dict(assignment_id=self.assignment_id)))
        self.strokes = []
        self.validated = True

    def _build_svg(self, d):
        """Return the SVG document text for the (already escaped) path data ``d``."""
        width = self.config['scanner']['width']
        height = self.config['scanner']['height']
        # The literal 0 after each viewBox dimension is appended textually;
        # presumably this gives ten user units per mm for integer sizes.
        return f"""<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
version="1.0" viewBox="0 0 {width}0 {height}0" width="{width}mm" height="{height}mm" preserveAspectRatio="none">
<path d="{d}" style='stroke:gray;stroke-width:2mm;fill:none;' id="stroke" />
</svg>
"""

    # the client sent the message
    def on_message(self, message):
        """Handle one JSON message from the worker.

        Supported ``action`` values: ``move``, ``up``, ``submit``, ``down``
        and ``info``. Malformed messages are logged and otherwise ignored.
        """
        logger.debug(f"recieve: {message}")

        if not self.validated:
            # open() rejected this connection; nothing here may be trusted.
            logger.warning("Skip message on a connection that did not pass open()")
            return

        if self.assignment_id != self.hit.getLastAssignment().assignment_id:
            logger.critical(f"Skip message for non-last assignment {message}")
            return

        if datetime.datetime.utcnow() > self.timeout:
            logger.critical("Close websocket after timeout (abandon?)")
            self.close()
            return

        try:
            msg = json.loads(message)
            action = msg['action']
            if action == 'move':
                # TODO: min/max input
                point = [float(msg['direction'][0]), float(msg['direction'][1]), bool(msg['mouse'])]
                self.strokes.append(point)
                self.plotterQ.put(point)
            elif action == 'up':
                logger.info(f'up: {msg}')
                # Only recorded locally; 'up' points are not sent to the plotter.
                point = [msg['direction'][0], msg['direction'][1], 1]
                self.strokes.append(point)
            elif action == 'submit':
                logger.info(f'submit: {msg}')
                token = self.submit_strokes()
                if not token:
                    self.write_message(json.dumps('error'))
                    return

                # store svg; the path data comes from the client, so escape
                # it before embedding it in the attribute
                svg = self._build_svg(html.escape(msg['d']))
                with open(self.store.currentHit.getSvgImagePath(), 'w') as fp:
                    fp.write(svg)

                self.write_message(json.dumps({
                    'action': 'submitted',
                    'msg': f"Submission ok, please copy this token to your HIT at Mechanical Turk: {self.assignment.uuid}",
                    'code': str(self.assignment.uuid)
                }))
                self.close()
            elif action == 'down':
                # not used, implicit in move?
                pass
            elif action == 'info':
                self.eventQ.put(Signal('assignment.info', dict(
                    hit_id=self.hit.id,
                    assignment_id=self.assignment_id,
                    resolution=msg['resolution'],
                    browser=msg['browser']
                )))
            else:
                logger.warning('Unknown request: {}'.format(message))
        except Exception as e:
            # Deliberately broad: a bad message must not tear down the socket.
            logger.exception(e)

    # client disconnected
    def on_close(self):
        """Unregister the connection and report 'server.close' when an assignment was set."""
        self.__class__.rmConnection(self)

        if self.assignment_id:
            self.eventQ.put(Signal('server.close', dict(assignment_id=self.assignment_id, abandoned=self.abandoned)))

        logger.info(f"Client disconnected: {self.request.remote_ip}")
        # TODO: abandon assignment??

    def submit_strokes(self):
        """Queue the 'assignment.submit' signal for this assignment.

        :returns: ``False`` when no strokes were received, otherwise the
            assignment's ``uuid``
        """
        if not self.strokes:
            return False

        self.eventQ.put(Signal("assignment.submit", dict(
            hit_id=self.hit.id,
            assignment_id=self.assignment_id)))

        return self.assignment.uuid

    @classmethod
    def rmConnection(cls, client):
        """Remove ``client`` from the registry; unknown clients are ignored."""
        cls.connections.discard(client)

    @classmethod
    def hasConnection(cls, client):
        """Return whether ``client`` is currently registered."""
        return client in cls.connections

    @classmethod
    def timeoutConnectionForAssignment(cls, assignment_id):
        """Mark every connection of ``assignment_id`` as abandoned and close it."""
        logger.warning(f"Check timeout for {assignment_id}")
        # Iterate over a snapshot: closing a client may end up removing it
        # from ``connections`` while this loop is still running.
        for client in list(cls.connections):
            logger.info(client.assignment_id)
            if client.assignment_id == assignment_id:
                client.abandoned = True
                client.close()
class StatusWebSocketHandler(tornado.websocket.WebSocketHandler):
    """Websocket that pushes HIT status data to the status page clients."""
    # Entries starting with a dot match any sub-domain; the others match the
    # host itself or any of its sub-domains.
    CORS_ORIGINS = ['localhost']
    # All currently registered handler instances (shared on the class).
    connections = set()

    def initialize(self, statusPage):
        """Keep a reference to the object whose ``fetch()`` provides the initial data."""
        self.statusPage = statusPage

    def check_origin(self, origin):
        """Return True when the Origin header's host is covered by ``CORS_ORIGINS``.

        An Origin without a hostname (for example ``null``) is rejected rather
        than raising. Matching happens on a domain boundary, so a host that
        merely ends in the same characters (``evillocalhost``) is refused.
        """
        # urlparse(...).hostname is lower-cased and excludes the port;
        # parsed_origin.netloc.lower() would give localhost:3333
        hostname = urlparse(origin).hostname
        if not hostname:
            return False
        for allowed in self.CORS_ORIGINS:
            if allowed.startswith('.'):
                if hostname.endswith(allowed):
                    return True
            elif hostname == allowed or hostname.endswith('.' + allowed):
                return True
        return False

    # the client connected
    def open(self):
        """Register the client and send it the latest HITs.

        Two items are sent unless the request carries an ``all`` query
        argument, in which case no limit is passed to ``fetch``.
        """
        self.__class__.connections.add(self)
        limit = None if 'all' in self.request.query_arguments else 2
        self.write_message(json.dumps(self.statusPage.fetch(limit), cls=DateTimeEncoder))

    # client disconnected
    def on_close(self):
        """Unregister the client."""
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        """Remove ``client`` from the registry; unknown clients are ignored."""
        cls.connections.discard(client)

    @classmethod
    def update_for_all(cls, data):
        """Send ``data`` (JSON-encoded once) to every registered client."""
        logger.debug(f"update for all {data}")
        payload = json.dumps(data, cls=DateTimeEncoder)
        # Snapshot the set: stale clients are dropped while looping.
        for connection in list(cls.connections):
            try:
                connection.write_message(payload)
            except tornado.websocket.WebSocketClosedError:
                # One already-closed socket must not stop the update from
                # reaching the remaining clients.
                cls.rmConnection(connection)
def strokes2D(strokes):
    """Convert a list of strokes into the ``d`` attribute of an SVG path.

    Each stroke is indexable as ``(x, y, flag)``. The first stroke becomes an
    absolute ``M`` command. Every later stroke is written relative to its
    predecessor: as an ``m`` when the predecessor's flag equals 1, otherwise
    as part of an ``l`` run.

    :param strokes: iterable of ``[x, y, flag]`` items
    :returns: the path data string (empty for no strokes)
    """
    pieces = []
    previous = None
    command = ""
    for current in strokes:
        if not previous:
            pieces.append(f"M{current[0]},{current[1]} ")
            command = 'M'
        else:
            if previous[2] == 1:
                pieces.append(" m")
                command = 'm'
            elif command != 'l':
                # Only emit the command letter at the start of an 'l' run.
                pieces.append(' l ')
                command = 'l'

            dx = current[0] - previous[0]
            dy = current[1] - previous[1]
            pieces.append(f"{dx},{dy} ")
        previous = current
    return "".join(pieces)
class DrawPageHandler(tornado.web.RequestHandler):
    """Serves the drawing page for the current HIT and registers the assignment."""

    # Value Mechanical Turk passes as assignmentId while a worker only previews a HIT.
    PREVIEW_ASSIGNMENT_ID = 'ASSIGNMENT_ID_NOT_AVAILABLE'

    def initialize(self, store: HITStore, eventQ: Queue, path: str, width: int, height: int, draw_width: int, draw_height: int, top_padding: int, left_padding: int, geoip_reader: geoip2.database.Reader):
        """Store the objects and page dimensions handed over by the routing table.

        :param store: HIT store providing the current and last submitted HIT
        :param eventQ: queue that receives ``Signal`` objects
        :param path: directory that contains ``index.html``
        :param width: value substituted for ``{WIDTH}``
        :param height: value substituted for ``{HEIGHT}``
        :param draw_width: value substituted for ``{DRAW_WIDTH}``
        :param draw_height: value substituted for ``{DRAW_HEIGHT}``
        :param top_padding: value substituted for ``{TOP_PADDING}``
        :param left_padding: value substituted for ``{LEFT_PADDING}``
        :param geoip_reader: reader used to resolve the visitor's country
        """
        self.store = store
        self.path = path
        self.width = width
        self.height = height
        self.draw_width = draw_width
        self.draw_height = draw_height
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.eventQ = eventQ
        self.geoip_reader = geoip_reader

    def get(self):
        """Render the drawing page for the HIT given by the ``id`` query argument.

        Writes a short plain-text message instead when the HIT is unknown, is
        not the current one, or was already submitted. For a real
        (non-preview) request the assignment is created when needed and
        visitor details are reported on the event queue.
        """
        try:
            hit_id = int(self.get_query_argument('id'))
            if hit_id != self.store.currentHit.id:
                self.write("Invalid HIT")
                return
            hit = self.store.currentHit
        except Exception:
            # Covers a missing or non-numeric id as well as a store without a current HIT.
            self.write("HIT not found")
            return

        if hit.isSubmitted():
            self.write("HIT already submitted")
            return

        raw_assignment_id = self.get_query_argument('assignmentId', '')
        # Detect preview mode on the raw value: once the worker id has been
        # appended the id can never equal the preview marker again.
        previewOnly = raw_assignment_id == self.PREVIEW_ASSIGNMENT_ID

        assignmentId = raw_assignment_id
        if assignmentId:
            assignmentId += '_' + str(self.get_query_argument('workerId', ''))
        else:
            logger.critical("Accessing page without assignment id. Allowing it for debug purposes... fingers crossed?")

        if assignmentId and not previewOnly:
            # process/create assignment
            assignment = self.store.currentHit.getAssignmentById(assignmentId)
            if not assignment:
                # new assignment
                logger.warning(f"Create new assignment {assignmentId}")
                assignment = self.store.newAssignment(self.store.currentHit, assignmentId)
                self.store.saveAssignment(assignment)

                # NOTE(review): the timer is scheduled only for a newly
                # created assignment; the original nesting could not be
                # recovered with certainty - confirm.
                logger.info(f"Set close timeout for {self.store.getHitTimeout()}")
                Server.loop.asyncio_loop.call_later(self.store.getHitTimeout(), WebSocketHandler.timeoutConnectionForAssignment, assignment.assignment_id)

        previous_hit = self.store.getLastSubmittedHit()
        if not previous_hit:
            # start with basic svg
            logger.warning("No previous HIT, start from basic svg")
            image = "/basic.svg"
        else:
            image = previous_hit.getSvgImageUrl()
            logger.info(f"Image url: {image}")

        self.set_header("Access-Control-Allow-Origin", "*")

        with open(os.path.join(self.path, 'index.html'), 'r') as fp:
            contents = fp.read()

        # Applied in this order, exactly like a chain of str.replace calls.
        substitutions = [
            ("{IMAGE_URL}", image),
            ("{WIDTH}", str(self.width)),
            ("{HEIGHT}", str(self.height)),
            ("{DRAW_WIDTH}", str(self.draw_width)),
            ("{DRAW_HEIGHT}", str(self.draw_height)),
            ("{TOP_PADDING}", str(self.top_padding)),
            ("{LEFT_PADDING}", str(self.left_padding)),
            ("{SCRIPT}", '' if previewOnly else '<script type="text/javascript" src="/assignment.js"></script>'),
            # The assignment id comes straight from the query string, so it
            # is escaped before it is inserted into the page.
            ("{ASSIGNMENT}", '' if previewOnly else html.escape(str(assignmentId))),
        ]
        for placeholder, value in substitutions:
            contents = contents.replace(placeholder, value)
        self.write(contents)

        if 'X-Forwarded-For' in self.request.headers:
            ip = self.request.headers['X-Forwarded-For']
        else:
            ip = self.request.remote_ip

        logger.info(f"Request from {ip}")

        # NOTE(review): all signals below are sent for non-preview requests
        # only; a preview has no assignment they could refer to - confirm
        # against the intended nesting.
        if not previewOnly:
            self.eventQ.put(Signal('hit.assignment', dict(
                hit_id=hit.id, ip=ip, assignment_id=assignmentId
            )))
            self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, ip=ip)))

            try:
                geoip = self.geoip_reader.country(ip)
                logger.debug(f"Geo {geoip}")
                self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, location=geoip.country.name)))
            except Exception as e:
                # Best effort: any lookup failure is reported as 'Unknown'.
                logger.exception(e)
                logger.info("No geo IP possible")
                self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, location='Unknown')))

            ua = self.request.headers.get('User-Agent', None)
            if ua:
                ua_info = httpagentparser.detect(ua)
                # Use .get so a missing 'os' or 'browser' entry cannot fail
                # the request after the page has been written.
                self.eventQ.put(Signal('assignment.info', dict(
                    assignment_id=assignmentId,
                    os=ua_info.get('os', {}).get('name'),
                    browser=ua_info.get('browser', {}).get('name'),
                )))
class BackendHandler(tornado.web.RequestHandler):
    """Serves the static backend page."""

    def initialize(self, store: HITStore, path: str):
        """Store the handler arguments.

        :param store: HIT store (kept on the handler, not read by ``get``)
        :param path: directory that contains ``backend/backend.html``
        """
        self.store = store
        self.path = path

    def get(self):
        """Write the contents of ``backend/backend.html`` to the response."""
        # The context manager closes the file handle after reading.
        with open(os.path.join(self.path, 'backend/backend.html'), 'r') as fp:
            contents = fp.read()
        self.write(contents)
class StatusPage():
    """
    Properties for on the status page, which are send over websockets the moment
    they are altered.
    """

    def __init__(self, store: HITStore):
        """Register this object as an update hook on ``store``."""
        self.store = store
        self.store.registerUpdateHook(self)

    def update(self, hit=None):
        """
        Send the given HIT formatted to the websocket clients
        If no hit is given, load the last 2 items

        The broadcast is scheduled on the server's event loop with
        ``call_soon_threadsafe``; without a server loop only a warning is logged.
        """
        if hit:
            data = [hit.toDict()]
        else:
            data = [newest.toDict() for newest in self.store.getNewestHits(2)]

        if Server.loop:
            Server.loop.asyncio_loop.call_soon_threadsafe(StatusWebSocketHandler.update_for_all, data)
        else:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning("Status: no server loop to call update command")

    def fetch(self, limit=2):
        """
        Fetch latest, used on connection of status page

        :param limit: passed on to ``store.getNewestHits``
        :returns: list with the ``toDict()`` form of every fetched HIT
        """
        return [hit.toDict() for hit in self.store.getNewestHits(limit)]
class Server:
    """
    Server for HIT -> plotter events
    As well as for the Status interface
    """
    # IOLoop of the running server; set in start() and read by other classes
    # to schedule callbacks onto the server's event loop.
    loop = None

    def __init__(self, config, eventQ: Queue, runningEvent: Event, plotterQ: Queue, store: HITStore):
        """
        :param config: configuration mapping; ``scanner`` and ``server`` sections are read in start()
        :param eventQ: queue handed to the handlers for ``Signal`` objects
        :param runningEvent: event that must be set for the loop to start; cleared when the server stops
        :param plotterQ: queue handed to the worker websocket
        :param store: HIT store shared with the handlers and the status page
        """
        self.isRunning = runningEvent
        self.eventQ = eventQ
        self.config = config
        self.logger = logger

        self.plotterQ = plotterQ  # communicate directly to plotter (skip main thread)

        self.web_root = os.path.join('www')

        self.server_loop = None
        self.store = store
        self.statusPage = StatusPage(store)

    def start(self):
        """Create the event loop and application, then serve until stopped.

        :raises Exception: when ``GeoLite2-Country.mmdb`` is missing from the working directory
        """
        if not os.path.exists('GeoLite2-Country.mmdb'):
            raise Exception("Please download the GeoLite2 Country database and place the 'GeoLite2-Country.mmdb' file in the project root.")

        self.geoip_reader = geoip2.database.Reader('GeoLite2-Country.mmdb')

        try:
            # A fresh asyncio loop is installed for the calling thread before
            # the tornado IOLoop is obtained.
            asyncio.set_event_loop(asyncio.new_event_loop())
            # NOTE(review): debug=True also enables tracebacks in error
            # responses; left as is, confirm this is wanted outside development.
            application = tornado.web.Application([
                (r"/ws(.*)", WebSocketHandler, {
                    'config': self.config,
                    'plotterQ': self.plotterQ,
                    'eventQ': self.eventQ,
                    'store': self.store,
                }),
                (r"/status/ws", StatusWebSocketHandler, dict(statusPage=self.statusPage)),
                (r"/draw", DrawPageHandler,
                    dict(
                        store=self.store,
                        eventQ=self.eventQ,
                        path=self.web_root,
                        width=self.config['scanner']['width'],
                        height=self.config['scanner']['height'],
                        draw_width=self.config['scanner']['draw_width'],
                        draw_height=self.config['scanner']['draw_height'],
                        top_padding=self.config['scanner']['top_padding'],
                        left_padding=self.config['scanner']['left_padding'],
                        geoip_reader=self.geoip_reader
                    )),
                (r"/backend", BackendHandler,
                    dict(
                        store=self.store,
                        path=self.web_root,
                    )),
                (r"/frames/(.*)", StaticFileWithHeaderHandler,
                    {"path": 'scanimation/interfaces/frames'}),
                # Catch-all: must stay last so the routes above win.
                (r"/(.*)", StaticFileWithHeaderHandler,
                    {"path": self.web_root}),
            ], debug=True, autoreload=False)
            application.listen(self.config['server']['port'])
            self.server_loop = tornado.ioloop.IOLoop.current()
            Server.loop = self.server_loop
            if self.isRunning.is_set():
                self.server_loop.start()
        finally:
            self.logger.info("Stopping webserver")
            self.isRunning.clear()
            # Release the GeoIP database file once the server no longer runs.
            self.geoip_reader.close()

    def stop(self):
        """Ask the server loop to stop; safe to call from another thread."""
        if self.server_loop:
            self.logger.debug("Got call to stop")
            self.server_loop.asyncio_loop.call_soon_threadsafe(self._stop)

    def _stop(self):
        """Stop the IOLoop; scheduled onto the loop by stop()."""
        self.server_loop.stop()