512 lines
19 KiB
Python
512 lines
19 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import tornado.ioloop
|
|
import tornado.web
|
|
import tornado.websocket
|
|
from urllib.parse import urlparse
|
|
import magic
|
|
|
|
from threading import Thread, Event
|
|
from queue import Queue, Empty
|
|
import asyncio
|
|
from sorteerhoed import HITStore
|
|
from sorteerhoed.Signal import Signal
|
|
import httpagentparser
|
|
import geoip2.database
|
|
import queue
|
|
import datetime
|
|
import html
|
|
|
|
|
|
# Module logger: the "webserver" child of the application-wide "sorteerhoed" logger.
logger = logging.getLogger("sorteerhoed.webserver")
|
|
|
|
|
|
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that additionally serialises ``datetime.datetime`` objects.

    Datetimes are written as ISO 8601 strings truncated to whole seconds
    (e.g. ``"2020-01-02T03:04:05"``). Every other type is left to the
    standard :class:`json.JSONEncoder`.
    """

    def default(self, o):
        if isinstance(o, datetime.datetime):
            return o.isoformat(timespec='seconds')

        # Defer to the base implementation, which raises the standard
        # "Object of type ... is not JSON serializable" TypeError.
        # (super() is already bound: passing self again was a bug.)
        return super().default(o)
|
|
|
|
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that adds CORS and SVG content-type headers.

    - ``.html`` files get ``Access-Control-Allow-Origin: *``.
    - ``.svg`` files are served as ``image/svg+xml``.
    - ``.png`` files are sniffed with libmagic. In testing, without a scanner,
      these "png" images are actually SVG documents, so they are then served
      as ``image/svg+xml`` too.
    """

    def set_extra_headers(self, path):
        """For subclass to add extra headers to the response.

        :param path: the requested file, relative to ``self.root``.
        """
        if path.endswith('.html'):
            self.set_header("Access-Control-Allow-Origin", "*")

        if path.endswith('.svg'):
            self.set_header("Content-Type", "image/svg+xml")

        if path.endswith('.png'):
            # in testing, without scanner, images are saved as svg
            mime = magic.from_file(os.path.join(self.root, path), mime=True)
            # log instead of printing to stdout on every request
            logger.debug(f"Mime type of {path}: {mime}")
            if mime == 'image/svg+xml':
                self.set_header("Content-Type", "image/svg+xml")
|
|
|
|
|
|
class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """
    Websocket from the workers

    Expects the query arguments ``id`` (must equal the current HIT's id) and
    ``assignmentId`` (must equal the HIT's last assignment). Pen movements
    are forwarded directly to the plotter through ``plotterQ``. Connection,
    client info and submission are reported on ``eventQ`` as ``Signal``
    objects.
    """

    # Allowed hostname suffixes for the Origin header
    CORS_ORIGINS = ['localhost', '.mturk.com', 'here.rubenvandeven.com', 'guest.rubenvandeven.com']
    # Currently open worker connections, shared by all instances
    connections = set()

    def initialize(self, config, plotterQ: Queue, eventQ: Queue, store: HITStore):
        self.config = config
        self.plotterQ = plotterQ
        self.eventQ = eventQ
        self.store = store

    def check_origin(self, origin):
        """Accept only origins whose hostname ends with one of ``CORS_ORIGINS``.

        Origins without a hostname (e.g. the literal ``"null"`` sent by
        sandboxed frames) are rejected instead of raising AttributeError.
        """
        # parsed_origin.netloc.lower() gives localhost:3333; .hostname drops the port
        hostname = urlparse(origin).hostname
        if not hostname:
            return False
        return any(hostname.endswith(allowed) for allowed in self.CORS_ORIGINS)

    # the client connected
    def open(self, p = None):
        self.__class__.connections.add(self)
        hit_id = int(self.get_query_argument('id'))
        if hit_id != self.store.currentHit.id:
            self.close()
            return

        self.hit = self.store.currentHit

        self.assignment_id = str(self.get_query_argument('assignmentId'))
        self.assignment = self.hit.getLastAssignment()

        if self.assignment.assignment_id != self.assignment_id:
            raise Exception(f"Opening websocket for invalid assignment {self.assignment_id}")

        self.timeout = datetime.datetime.now() + datetime.timedelta(seconds=self.store.getHitTimeout())

        if self.hit.isSubmitted():
            raise Exception("Opening websocket for already submitted hit")

        #logger.info(f"New client connected: {self.request.remote_ip} for {self.hit.id}/{self.hit.hit_id}")
        self.eventQ.put(Signal('server.open', dict(assignment_id=self.assignment_id)))
        # set last: its presence marks a fully accepted connection (see on_message)
        self.strokes = []

    # the client sent the message
    def on_message(self, message):
        """Handle one JSON message ``{"action": ..., ...}`` from the drawing page."""
        logger.debug(f"recieve: {message}")

        if not hasattr(self, 'strokes'):
            # open() rejected this connection (e.g. stale HIT id) and is closing
            # it; frames already in flight must not crash on missing attributes
            logger.warning(f"Ignore message on rejected websocket: {message}")
            return

        if self.assignment_id != self.hit.getLastAssignment().assignment_id:
            logger.critical(f"Skip message for non-last assignment {message}")
            return

        if datetime.datetime.now() > self.timeout:
            logger.critical("Close websocket after timeout (abandon?)")
            self.close()
            return

        try:
            msg = json.loads(message)
            action = msg['action']

            if action == 'move':
                # TODO: min/max input
                point = [float(msg['direction'][0]), float(msg['direction'][1]), bool(msg['mouse'])]
                self.strokes.append(point)
                self.plotterQ.put(point)

            elif action == 'up':
                logger.info(f'up: {msg}')
                # cast like 'move' so self.strokes only holds numbers
                point = [float(msg['direction'][0]), float(msg['direction'][1]), 1]
                self.strokes.append(point)

            elif action == 'submit':
                logger.info(f'submit: {msg}')
                token = self.submit_strokes()
                if not token:
                    self.write_message(json.dumps('error'))
                    return

                self._store_svg(msg['d'])

                self.write_message(json.dumps({
                    'action': 'submitted',
                    'msg': f"Submission ok, please copy this token to your HIT at Mechanical Turk: {self.assignment.uuid}",
                    'code': str(self.assignment.uuid)
                }))
                self.close()

            elif action == 'down':
                # not used, implicit in move?
                pass

            elif action == 'info':
                self.eventQ.put(Signal('assignment.info', dict(
                    hit_id=self.hit.id,
                    assignment_id=self.assignment_id,
                    resolution=msg['resolution'],
                    browser=msg['browser']
                )))

            else:
                # self.send({'alert': 'Unknown request: {}'.format(message)})
                logger.warning('Unknown request: {}'.format(message))

        except Exception as e:
            # self.send({'alert': 'Invalid request: {}'.format(e)})
            logger.exception(e)

    def _store_svg(self, d):
        """Write the submitted SVG path data *d* as the SVG image of this socket's HIT.

        The viewBox numbers get a literal ``0`` appended (ten times the
        scanner size for integer sizes); width/height are the scanner size
        in mm.
        """
        # d comes from the client: escape it before placing it in an attribute
        d = html.escape(d)
        scanner = self.config['scanner']
        svg = f"""<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg"
version="1.0" viewBox="0 0 {scanner['width']}0 {scanner['height']}0" width="{scanner['width']}mm" height="{scanner['height']}mm" preserveAspectRatio="none">
<path d="{d}" style='stroke:gray;stroke-width:2mm;fill:none;' id="stroke" />
</svg>
"""
        # Use self.hit, not store.currentHit: the submit signal has already been
        # queued, so the store may have moved on to a new HIT by now.
        # The XML header declares UTF-8, so write UTF-8 regardless of locale.
        with open(self.hit.getSvgImagePath(), 'w', encoding='utf-8') as fp:
            fp.write(svg)

    # client disconnected
    def on_close(self):
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")
        # TODO: abandon assignment??

    def submit_strokes(self):
        """Report the drawing as submitted.

        :returns: ``False`` when nothing was drawn. Otherwise it puts an
            ``assignment.submit`` Signal on eventQ and returns the
            assignment's uuid (the token shown to the worker).
        """
        if not self.strokes:
            return False

        self.eventQ.put(Signal("assignment.submit", dict(
            hit_id=self.hit.id,
            assignment_id=self.assignment_id)))

        # Writing a dummy SVG (via strokes2D) and faking a 'hit.scanned' event
        # used to happen here; that is now done at the scanner.
        return self.assignment.uuid

    @classmethod
    def rmConnection(cls, client):
        """Forget *client*; a no-op when it is not registered."""
        cls.connections.discard(client)
|
|
|
|
|
|
class StatusWebSocketHandler(tornado.websocket.WebSocketHandler):
    """Websocket for the status page; pushes HIT data as JSON to every client."""

    # Allowed hostname suffixes for the Origin header
    CORS_ORIGINS = ['localhost']
    # Currently open status-page connections, shared by all instances
    connections = set()

    def initialize(self, statusPage):
        self.statusPage = statusPage

    def check_origin(self, origin):
        """Accept only origins whose hostname ends with one of ``CORS_ORIGINS``.

        Origins without a hostname (e.g. ``"null"``) are rejected instead of
        raising AttributeError.
        """
        # parsed_origin.netloc.lower() gives localhost:3333; .hostname drops the port
        hostname = urlparse(origin).hostname
        if not hostname:
            return False
        return any(hostname.endswith(allowed) for allowed in self.CORS_ORIGINS)

    # the client connected
    def open(self):
        self.__class__.connections.add(self)
        # send the current state right away
        self.write_message(json.dumps(self.statusPage.fetch(), cls=DateTimeEncoder))

    # client disconnected
    def on_close(self):
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        """Forget *client*; a no-op when it is not registered."""
        cls.connections.discard(client)

    @classmethod
    def update_for_all(cls, data):
        """Send *data* as JSON to every connected status client.

        A client that is already closed is dropped instead of aborting the
        broadcast for the remaining clients.
        """
        logger.debug(f"update for all {data}")
        # serialise once for all clients
        payload = json.dumps(data, cls=DateTimeEncoder)
        # iterate over a snapshot, as connections may be removed meanwhile
        for connection in list(cls.connections):
            try:
                connection.write_message(payload)
            except tornado.websocket.WebSocketClosedError:
                logger.warning("Status client already closed, dropping it")
                cls.rmConnection(connection)
|
|
|
|
def strokes2D(strokes):
    """Build the ``d`` attribute of an SVG path from a sequence of strokes.

    Each stroke is indexed as ``stroke[0]`` (x), ``stroke[1]`` (y) and
    ``stroke[2]`` (pen flag). The first stroke becomes an absolute ``M``
    move. Every later stroke is written relative to the previous one:

    - it is preceded by ``m`` when the previous stroke's flag equals 1;
    - otherwise a single ``l`` is emitted at the start of each run of
      line segments.
    """
    pieces = []
    mode = ""
    previous = None

    for point in strokes:
        if not previous:
            pieces.append(f"M{point[0]},{point[1]} ")
            mode = "M"
            previous = point
            continue

        if previous[2] == 1:
            pieces.append(" m")
            mode = "m"
        elif mode != "l":
            pieces.append(" l ")
            mode = "l"

        pieces.append(f"{point[0] - previous[0]},{point[1] - previous[1]} ")
        previous = point

    return "".join(pieces)
|
|
|
|
class DrawPageHandler(tornado.web.RequestHandler):
    """Serve the drawing page (``index.html``) for the current HIT.

    Query arguments:

    - ``id``: the HIT id; must equal ``store.currentHit.id``.
    - ``assignmentId``: the MTurk assignment id. The value
      ``ASSIGNMENT_ID_NOT_AVAILABLE`` means preview only: no assignment
      script is included and nothing is reported.

    An unknown assignment id is created and saved through the store. Visitor
    details (IP, country, OS, browser) are put on ``eventQ`` as ``Signal``
    objects.
    """

    def initialize(self, store: HITStore, eventQ: Queue, path: str, width: int, height: int, draw_width: int, draw_height: int, top_padding: int, left_padding: int, geoip_reader: geoip2.database.Reader):
        self.store = store
        self.path = path
        self.width = width
        self.height = height
        self.draw_width = draw_width
        self.draw_height = draw_height
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.eventQ = eventQ
        self.geoip_reader = geoip_reader

    def get(self):
        try:
            hit_id = int(self.get_query_argument('id'))
            if hit_id != self.store.currentHit.id:
                self.write("Invalid HIT")
                return
            hit = self.store.currentHit
        except Exception:
            self.write("HIT not found")
            return

        if hit.isSubmitted():
            self.write("HIT already submitted")
            return

        assignmentId = self.get_query_argument('assignmentId', '')
        if not assignmentId:
            logger.critical("Accessing page without assignment id. Allowing it for debug purposes... fingers crossed?")

        # MTurk sends this placeholder while the worker only previews the HIT
        previewOnly = assignmentId == 'ASSIGNMENT_ID_NOT_AVAILABLE'

        if assignmentId and not previewOnly:
            # process/create assignment
            assignment = hit.getAssignmentById(assignmentId)
            if not assignment:
                # new assignment
                logger.warning(f"Create new assignment {assignmentId}")
                assignment = self.store.newAssignment(hit, assignmentId)
                self.store.saveAssignment(assignment)

        image = self._image_url()

        self.set_header("Access-Control-Allow-Origin", "*")
        self.write(self._render_page(image, assignmentId, previewOnly))

        ip = self._client_ip()
        logger.info(f"Request from {ip}")

        if not previewOnly:
            self._report_client(hit, ip, assignmentId)

    def _image_url(self):
        """Return the image to continue drawing on: the last submitted HIT's SVG, or ``/basic.svg``."""
        previous_hit = self.store.getLastSubmittedHit()
        if not previous_hit:
            # start with basic svg
            logger.warning("No previous HIT, start from basic svg")
            image = "/basic.svg"
        else:
            image = previous_hit.getSvgImageUrl()
        logger.info(f"Image url: {image}")
        return image

    def _render_page(self, image, assignmentId, previewOnly):
        """Return ``index.html`` with its ``{PLACEHOLDER}`` fields filled in."""
        with open(os.path.join(self.path, 'index.html'), 'r', encoding='utf-8') as fp:
            contents = fp.read()

        # replaced in this order, exactly like the former chained .replace() calls
        replacements = {
            "{IMAGE_URL}": image,
            "{WIDTH}": str(self.width),
            "{HEIGHT}": str(self.height),
            "{DRAW_WIDTH}": str(self.draw_width),
            "{DRAW_HEIGHT}": str(self.draw_height),
            "{TOP_PADDING}": str(self.top_padding),
            "{LEFT_PADDING}": str(self.left_padding),
            "{SCRIPT}": '' if previewOnly else '<script type="text/javascript" src="/assignment.js"></script>',
            # assignmentId is a GET argument: escape it to prevent reflected XSS
            # (genuine MTurk ids are alphanumeric and pass through unchanged)
            "{ASSIGNMENT}": '' if previewOnly else html.escape(str(assignmentId)),
        }
        for placeholder, value in replacements.items():
            contents = contents.replace(placeholder, value)
        return contents

    def _client_ip(self):
        """Return the client IP: the first ``X-Forwarded-For`` entry, else the peer address.

        NOTE: X-Forwarded-For is set by the client/proxy chain and can be spoofed.
        """
        forwarded = self.request.headers.get('X-Forwarded-For')
        if forwarded:
            # the header may hold a chain "client, proxy1, proxy2"
            return forwarded.split(',')[0].strip()
        return self.request.remote_ip

    def _report_client(self, hit, ip, assignmentId):
        """Put the assignment's IP, country, OS and browser on eventQ as Signals."""
        self.eventQ.put(Signal('hit.assignment', dict(
            hit_id=hit.id, ip=ip, assignment_id=assignmentId
        )))

        self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, ip=ip)))

        try:
            geoip = self.geoip_reader.country(ip)
            logger.debug(f"Geo {geoip}")
            self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, location=geoip.country.name)))
        except Exception as e:
            logger.exception(e)
            logger.info("No geo IP possible")
            self.eventQ.put(Signal('assignment.info', dict(assignment_id=assignmentId, location='Unknown')))

        ua = self.request.headers.get('User-Agent', None)
        if ua:
            ua_info = httpagentparser.detect(ua)
            # detect() may leave out 'os'/'browser' for unrecognised agents;
            # report None for those instead of failing with KeyError
            self.eventQ.put(Signal('assignment.info', dict(
                assignment_id=assignmentId,
                os=(ua_info.get('os') or {}).get('name'),
                browser=(ua_info.get('browser') or {}).get('name'),
            )))
|
|
|
|
class BackendHandler(tornado.web.RequestHandler):
    """Render ``backend.html`` with one table row per HIT from ``store.getHITs(100)``."""

    def initialize(self, store: HITStore, path: str):
        self.store = store
        self.path = path

    def get(self):
        def esc(value):
            # str() first so None/datetime render as before, then escape markup:
            # turk_ip originates from the client-settable X-Forwarded-For header
            return html.escape(str(value))

        rows = []
        for hit in self.store.getHITs(100):
            if hit.submit_hit_at and hit.accept_time:
                seconds = (hit.submit_hit_at - hit.accept_time).total_seconds()
                duration_m = int(seconds/60)
                duration_s = max(int(seconds%60), 0)
                duration = (f"{duration_m}m" if duration_m else "") + f"{duration_s:02d}s"
            else:
                duration = "-"

            # '.2f' = two decimals (the former '.2' meant two significant digits,
            # e.g. 1.5 -> "1.5" and 10 -> "1e+01")
            fee = f"${hit.fee:.2f}" if hit.fee else "-"

            rows.append(
                f"""
<tr><td></td><td>{esc(hit.worker_id)}</td>
<td>{esc(hit.turk_ip)}</td>
<td>{esc(hit.turk_country)}</td>
<td>{fee}</td>
<td>{esc(hit.accept_time)}</td>
<td>{duration}</td><td></td></tr>
"""
            )

        with open(os.path.join(self.path, 'backend.html'), 'r', encoding='utf-8') as fp:
            contents = fp.read()
        contents = contents.replace("{{TBODY}}", "".join(rows))
        self.write(contents)
|
|
|
|
class StatusPage():
    """
    Properties for the status page, which are sent over websockets the moment
    they are altered.
    """

    def __init__(self, store: HITStore):
        self.store = store
        # presumably the store calls update() on its hooks when a HIT changes -- TODO confirm
        self.store.registerUpdateHook(self)

    def update(self, hit = None):
        """
        Send the given HIT formatted to the websocket clients.

        If no hit is given, send the 2 newest HITs (same data as fetch()).
        The broadcast is scheduled on the server's loop with
        call_soon_threadsafe, so this may be called from outside that loop's
        thread.
        """
        if hit:
            data = [hit.toDict()]
        else:
            data = self.fetch()

        if Server.loop:
            Server.loop.asyncio_loop.call_soon_threadsafe(StatusWebSocketHandler.update_for_all, data)
        else:
            logger.warning("Status: no server loop to call update command")

    def fetch(self):
        """
        Fetch latest, used on connection of status page.

        :returns: ``toDict()`` of the 2 newest HITs.
        """
        return [newest.toDict() for newest in self.store.getNewestHits(2)]
|
|
|
|
|
|
|
|
class Server:
    """
    Server for HIT -> plotter events
    As well as for the Status interface
    """
    # The running tornado IOLoop, shared so StatusPage can schedule updates on it;
    # None while no server loop is running
    loop = None

    def __init__(self, config, eventQ: Queue, runningEvent: Event, plotterQ: Queue, store: HITStore):
        self.isRunning = runningEvent
        self.eventQ = eventQ
        self.config = config
        self.logger = logger

        self.plotterQ = plotterQ # communicate directly to plotter (skip main thread)

        #self.config['server']['port']
        self.web_root = 'www'

        self.server_loop = None
        self.store = store
        self.statusPage = StatusPage(store)

    def start(self):
        """Build the tornado application and run its IOLoop in the calling thread.

        The loop is only started when ``runningEvent`` is set; start() then
        blocks until the loop is stopped (see stop()). Afterwards
        ``runningEvent`` is cleared, ``Server.loop`` is reset and the GeoIP
        database is closed.

        :raises Exception: when ``GeoLite2-Country.mmdb`` is missing from the
            working directory.
        """
        if not os.path.exists('GeoLite2-Country.mmdb'):
            raise Exception("Please download the GeoLite2 Country database and place the 'GeoLite2-Country.mmdb' file in the project root.")

        self.geoip_reader = geoip2.database.Reader('GeoLite2-Country.mmdb')

        try:
            # this thread needs its own asyncio loop for tornado
            asyncio.set_event_loop(asyncio.new_event_loop())
            scanner = self.config['scanner']
            application = tornado.web.Application([
                (r"/ws(.*)", WebSocketHandler, {
                    'config': self.config,
                    'plotterQ': self.plotterQ,
                    'eventQ': self.eventQ,
                    'store': self.store,
                }),
                (r"/status/ws", StatusWebSocketHandler, dict(statusPage = self.statusPage)),
                (r"/draw", DrawPageHandler,
                    dict(
                        store = self.store,
                        eventQ = self.eventQ,
                        path=self.web_root,
                        width=scanner['width'],
                        height=scanner['height'],
                        draw_width=scanner['draw_width'],
                        draw_height=scanner['draw_height'],
                        top_padding=scanner['top_padding'],
                        left_padding=scanner['left_padding'],
                        geoip_reader= self.geoip_reader
                    )),
                (r"/backend", BackendHandler,
                    dict(
                        store = self.store,
                        path=self.web_root,
                    )),
                (r"/frames/(.*)", StaticFileWithHeaderHandler,
                    {"path": 'scanimation/interfaces/frames'}),
                (r"/(.*)", StaticFileWithHeaderHandler,
                    {"path": self.web_root}),
            ], debug=True, autoreload=False)
            application.listen(self.config['server']['port'])
            self.server_loop = tornado.ioloop.IOLoop.current()
            Server.loop = self.server_loop
            if self.isRunning.is_set():
                self.server_loop.start()
        finally:
            self.logger.info("Stopping webserver")
            self.isRunning.clear()
            # the loop no longer runs: stop StatusPage from scheduling on it
            Server.loop = None
            # release the GeoIP database file handle
            self.geoip_reader.close()

    def stop(self):
        """Ask the server loop to stop; the request is handed over with call_soon_threadsafe."""
        if self.server_loop:
            self.logger.debug("Got call to stop")
            self.server_loop.asyncio_loop.call_soon_threadsafe(self._stop)

    def _stop(self):
        # scheduled by stop() to run on the server loop itself
        self.server_loop.stop()
|