diff --git a/sorteerhoed/HITStore.py b/sorteerhoed/HITStore.py
index 798918d..28b8397 100644
--- a/sorteerhoed/HITStore.py
+++ b/sorteerhoed/HITStore.py
@@ -32,7 +32,7 @@ submitted
Actions:
creating Hit (creating hit with scanned image)
-Scanning
+Scanning
"""
@@ -42,35 +42,34 @@ class HIT(Base):
hit_id = Column(String(255)) # amazon's hit id
created_at = Column(DateTime, default=datetime.datetime.utcnow)
updated_at = Column(DateTime, default=datetime.datetime.utcnow)
- uuid = Column(String(32), default=lambda : uuid.uuid4().hex)
- assignment_id = Column(String(255), default = None)
- worker_id = Column(String(255), default = None)
- accept_time = Column(DateTime, default=None)
- open_page_at = Column(DateTime, default=None)
- submit_page_at = Column(DateTime, default=None)
- submit_hit_at = Column(DateTime, default=None)
- answer = Column(String(255), default=None)
- turk_ip = Column(String(255), default=None)
- turk_country = Column(String(255), default=None)
- turk_screen_width = Column(Integer, default = None)
- turk_screen_height = Column(Integer, default = None)
scanned_at = Column(DateTime, default=None)
+ deleted_at = Column(DateTime, default=None)
+ assignments = relationship("Assignment", back_populates="hit", order_by="Assignment.created_at")
fee = Column(Float(precision=2), default=None)
- abandoned = False
-
-
+
+ # previous hit so we can load the current image
+ # previous_hit_id = Column(Integer, ForeignKey('hits.id'), default=None)
+ # previous_hit = relationship("HIT")
+
def getImagePath(self):
return os.path.join('scanimation/interfaces/frames', f"{self.id:06d}.jpg")
-
-# def getImageUrl(self):
-# return f"{self.id}.jpg"
-
+
def getSvgImageUrl(self):
return f"scans/{self.id:06d}.svg"
-
+
def getSvgImagePath(self):
return os.path.join('www', self.getSvgImageUrl())
+ def getLastAssignment(self):
+     """Return the final entry of ``self.assignments``, or None when the list is empty."""
+     assignments = self.assignments
+     return assignments[-1] if len(assignments) else None
+
+ def getAssignmentById(self, assignmentId):
+     """Return the assignment of this HIT whose ``assignment_id`` equals *assignmentId*.
+
+     Returns None when none of ``self.assignments`` matches.
+     """
+     for a in self.assignments:
+         if a.assignment_id == assignmentId:
+             # hand back the match; the bare `return` here always yielded None
+             return a
+     return None
+
def getStatus(self):
if self.scanned_at:
return "completed"
@@ -86,22 +85,44 @@ class HIT(Base):
if self.worker_id:
return "abandoned by worker"
return "awaiting worker"
-
+
+
+class Assignment(Base):
+    """Row in the ``assignments`` table; linked to its HIT through ``hit_id`` / ``hit``."""
+    __tablename__ = 'assignments'
+    id = Column(Integer, Sequence('assignment_id'), primary_key=True) # our sequential assignment id
+    hit_id = Column(Integer, ForeignKey('hits.id')) # our sequential hit id
+    hit = relationship("HIT", back_populates="assignments")
+    uuid = Column(String(32), default=lambda : uuid.uuid4().hex)
+
+    created_at = Column(DateTime, default=datetime.datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.datetime.utcnow)
+
+    # amazon's assignment id. Declared once: this column used to be declared
+    # twice in the class body, and the later declaration silently replaced the
+    # earlier one.
+    assignment_id = Column(String(255), default = None)
+    worker_id = Column(String(255), default = None)
+    accept_at = Column(DateTime, default=None)
+    # open_page_at = Column(DateTime, default=None)
+    submit_page_at = Column(DateTime, default=None) # Submit the page
+    confirmed_at = Column(DateTime, default=None) # validate with UUID when getting Message from Amazon
+    abandoned_at = Column(DateTime, default=None)
+    rejected_at = Column(DateTime, default=None)
+
+    answer = Column(String(255), default=None)
+    turk_ip = Column(String(255), default=None)
+    turk_country = Column(String(255), default=None)
+
+
class Store:
def __init__(self, db_filename, logLevel=0):
path = os.path.abspath(db_filename)
if logLevel <= logging.DEBUG:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
-
+
self.engine = create_engine('sqlite:///'+path, echo=False, connect_args={'check_same_thread': False})
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)
self.session = self.Session()
-
+
self.currentHit = None # mirrors Centralmanagmenet, stored here so we can quickly access it from webserver classes
-
+
@contextmanager
def getSession(self):
"""Provide a transactional scope around a series of operations."""
@@ -111,25 +132,25 @@ class Store:
except:
self.session.rollback()
raise
-
+
def getHits(self, session):
return self.session.query(Source).order_by(HIT.created_at.desc())
-
+
def getHitById(self, hitId):
return self.session.query(HIT).\
filter(HIT.id==hitId).one()
-
+
def getHitByRemoteId(self, amazonHitId):
return self.session.query(HIT).\
filter(HIT.hit_id==amazonHitId).one()
-
-
+
+
def getLastSubmittedHit(self):
return self.session.query(HIT).\
filter(HIT.submit_page_at!=None).\
order_by(HIT.submit_page_at.desc()).first()
-
- def createHIT(self):
+
+ def createHIT(self) -> HIT:
with self.getSession() as s:
hit = HIT()
s.add(hit)
@@ -137,19 +158,29 @@ class Store:
s.refresh(hit)
logger.info(f"Created HIT {hit.id}")
return hit
-
+
+ def newAssignment(self, hit: HIT) -> Assignment:
+     """Create an Assignment, append it to ``hit.assignments`` and flush it.
+
+     The flush happens inside the session scope so that ``assignment.id`` is
+     populated before it is logged and returned.
+     """
+     with self.getSession() as s:
+         assignment = Assignment()
+         hit.assignments.append(assignment)
+         s.add(assignment)
+         s.flush()
+         s.refresh(hit)
+         logger.info(f"Created Assignment {assignment.id}")
+         return assignment
+
+ def saveAssignment(self, assignment: Assignment):
+     """Counterpart of saveHIT for assignments.
+
+     central_management's 'hit.assignment' handler calls this method, which
+     did not exist on Store.
+     NOTE(review): like saveHIT, this relies on getSession() to persist
+     pending changes when the scope exits - confirm against getSession.
+     """
+     with self.getSession() as s:
+         logger.info(f"Updating assignment! {assignment.id}")
+
def saveHIT(self, hit):
with self.getSession() as s:
logger.info(f"Updating hit! {hit.id}")
# s.flush()
-
+
def addHIT(self, hit: HIT):
with self.getSession() as s:
s.add(hit)
s.flush()
s.refresh(hit)
logger.info(f"Added {hit.id}")
-
+
def getAvgDurationOfPreviousNHits(self, n) -> int:
latest_hits = self.session.query(HIT).\
filter(HIT.submit_hit_at!=None).\
@@ -161,18 +192,18 @@ class Store:
if not len(durations):
return int(2.5*60)
return int(sum(durations) / len(durations))
-
+
def getEstimatedHitDuration(self):
return self.getAvgDurationOfPreviousNHits(5)
-
+
def getHitTimeout(self):
return 160 # max(160, self.getAvgDurationOfPreviousNHits(5)*2)
-
+
def getHITs(self, n = 100):
return self.session.query(HIT).\
filter(HIT.submit_hit_at != None).\
order_by(HIT.submit_hit_at.desc()).limit(n)
-
+
# def rmSource(self, id: int):
# with self.getSession() as session:
# source = session.query(Source).get(id)
@@ -181,7 +212,6 @@ class Store:
# else:
# logging.info(f"Deleting source {source.id}: {source.url}")
# session.delete(source)
-#
+#
# def getRandomNewsItem(self, session) -> NewsItem:
# return session.query(NewsItem).order_by(func.random()).limit(1).first()
-
diff --git a/sorteerhoed/central_management.py b/sorteerhoed/central_management.py
index a57604a..6515097 100644
--- a/sorteerhoed/central_management.py
+++ b/sorteerhoed/central_management.py
@@ -76,6 +76,7 @@ class CentralManagement():
self.lastHitTime = None
self.eventQueue = Queue()
+ self.statusPageQueue = Queue()
self.isRunning = threading.Event()
self.isScanning = threading.Event()
self.scanLock = threading.Lock()
@@ -210,38 +211,67 @@ class CentralManagement():
if signal.name == 'start':
self.makeHit()
self.lastHitTime = datetime.datetime.now()
- pass
+ elif signal.name == 'hit.scan':
+     # a scan signal that does not belong to the current HIT is ignored
+     if signal.params['id'] != self.currentHit.id:
+         self.logger.info(f"Hit.scan had wrong id: {signal}")
+         continue
+     # Queue objects are filled with put() (as eventQueue is); there is no add()
+     self.statusPageQueue.put(dict(hit_id=signal.params['id'], transition='scanning'))
+
elif signal.name == 'hit.scanned':
# TODO: wrap up hit & make new HIT
+ if signal.params['id'] != self.currentHit.id:
+ self.logger.info(f"Hit.scanned had wrong id: {signal}")
+ continue
+
self.currentHit.scanned_at = datetime.datetime.utcnow()
- self.server.statusPage.set('state', self.currentHit.getStatus())
time_diff = datetime.datetime.now() - self.lastHitTime
to_wait = 10 - time_diff.total_seconds()
+ # Queue objects are filled with put() (as eventQueue is); there is no add()
+ self.statusPageQueue.put(dict(hit_id=self.currentHit.id, state='scan'))
+
if to_wait > 0:
self.logger.warn(f"Sleep until next hit: {to_wait}s")
time.sleep(to_wait)
else:
self.logger.info(f"No need to wait: {to_wait}s")
+
self.makeHit()
self.lastHitTime = datetime.datetime.now()
+ elif signal.name == 'hit.creating':
+     # Queue objects are filled with put() (as eventQueue is); there is no add()
+     self.statusPageQueue.put(dict(hit_id=signal.params['id'], transition='create_hit'))
+ elif signal.name == 'hit.created':
+     self.statusPageQueue.put(dict(hit_id=signal.params['id'], remote_id=signal.params['remote_id'], state='hit'))
+
elif signal.name == 'scan.start':
pass
elif signal.name == 'scan.finished':
+ # probably see hit.scanned
pass
- elif signal.name == 'hit.info':
- if signal.params['hit_id'] != self.currentHit.id:
- self.logger.warning(f"hit.info hit_id != currenthit.id: {signal}")
- continue
- for name, value in signal.params.items():
- if name == 'hit_id':
- continue
- if name == 'ip':
- self.currentHit.turk_ip = value
- if name == 'location':
- self.currentHit.turk_country = value
- self.logger.debug(f'Set status: {name} to {value}')
+ elif signal.name == 'hit.assignment':
+     # Create new assignment (only for the HIT that is currently active)
+     if signal.params['hit_id'] != self.currentHit.id:
+         continue
+
+     assignment = self.store.newAssignment(self.currentHit)
+     assignment.assignment_id = signal.params['assignment_id']
+     self.store.saveAssignment(assignment)
+
+     # Queue objects are filled with put() (as eventQueue is); there is no add()
+     self.statusPageQueue.put(dict(hit_id=self.currentHit.id, assignment_id=assignment.assignment_id, state='assignment'))
+
+ elif signal.name == 'assignment.info':
+     assignment = self.currentHit.getAssignmentById(signal.params['assignment_id'])
+     if not assignment:
+         self.logger.warning(f"assignment.info assignment.id not for current hit assignments: {signal}")
+         # nothing to update: skip this signal instead of dereferencing None below
+         continue
+
+     for name, value in signal.params.items():
+         if name == 'ip':
+             assignment.turk_ip = value
+         if name == 'location':
+             assignment.turk_country = value
+
+         self.logger.debug(f'Set assignment: {name} to {value}')
self.server.statusPage.set(name, value)
+
elif signal.name == 'server.open':
self.currentHit.open_page_at = datetime.datetime.utcnow()
self.store.saveHIT(self.currentHit)
@@ -345,12 +375,13 @@ class CentralManagement():
def makeHit(self):
self.expireCurrentHit() # expire hit if it is there
+ self.eventQueue.put(Signal('hit.creating', {'id': self.currentHit.id if self.currentHit else 'start'}))
+
self.server.statusPage.reset()
self.reloadConfig() # reload new config values if they are set
# self.notPaused.wait()
-
self.currentHit = self.store.createHIT()
self.store.currentHit = self.currentHit
@@ -392,10 +423,12 @@ class CentralManagement():
self.store.saveHIT(self.currentHit)
# TODO: have HITStore/HIT take care of this by emitting a signal
- self.server.statusPage.set('hit_id', new_hit['HIT']['HITId'])
- self.server.statusPage.set('hit_created', self.currentHit.created_at)
- self.server.statusPage.set('fee', f"${self.currentHit.fee:.2f}")
- self.server.statusPage.set('state', self.currentHit.getStatus())
+ # self.server.statusPage.set('hit_id', new_hit['HIT']['HITId'])
+ # self.server.statusPage.set('hit_created', self.currentHit.created_at)
+ # self.server.statusPage.set('fee', f"${self.currentHit.fee:.2f}")
+ # self.server.statusPage.set('state', self.currentHit.getStatus())
+
+ self.eventQueue.put(Signal('hit.created', {'id': self.currentHit.id, 'remote_id': self.currentHit.hit_id}))
# mturk.send_test_event_notification()
if self.config['amazon']['sqs_url']:
@@ -432,8 +465,8 @@ class CentralManagement():
'sudo', 'scanimage', '-d', 'epkowa','--resolution=100',
'-l','25' #y axis, margin from top of the scanner, hence increasing this, moves the scanned image upwards
,'-t','22', # x axis, margin from left side scanner (seen from the outside)
- '-x',str(181),
- '-y',str(245)
+ '-x',str(181),
+ '-y',str(245)
]
proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
# opens connection to scanner, but only starts scanning when output becomes ready:
@@ -469,6 +502,7 @@ class CentralManagement():
filename = self.currentHit.getImagePath()
with self.scanLock:
+ self.eventQueue.put(Signal('hit.scan', {'id':self.currentHit.id}))
self.eventQueue.put(Signal('scan.start'))
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# opens connection to scanner, but only starts scanning when output becomes ready:
@@ -491,7 +525,7 @@ class CentralManagement():
time.sleep(5) # sleep a few seconds for scanner to return to start position
- self.eventQueue.put(Signal('hit.scanned', {'hit_id':self.currentHit.id}))
+ self.eventQueue.put(Signal('hit.scanned', {'id':self.currentHit.id}))
self.eventQueue.put(Signal('scan.finished'))
def setLight(self, on):
diff --git a/sorteerhoed/statemachine.py b/sorteerhoed/statemachine.py
new file mode 100644
index 0000000..64203b6
--- /dev/null
+++ b/sorteerhoed/statemachine.py
@@ -0,0 +1,89 @@
+import datetime
+
+class State():
+    """Base class for one step in the HIT state machine.
+
+    Subclasses list the transition names they accept in
+    ``availableTransitions`` and return the follow-up state from
+    ``transition``.
+    """
+
+    def __init__(self, hit_id):
+        self.time = datetime.datetime.now()  # moment this state was created
+        # the id arrives as the argument; `params` is not defined in this scope
+        self.hit_id = hit_id
+
+    def transition(self, transitionName, params = {}):
+        """Return the state that follows *transitionName*; subclasses override this."""
+        raise NotImplementedError("Not implemented")
+
+class StateMachine:
+    """Ordered history of ``(transitionName, state)`` pairs; the last pair is current."""
+
+    def __init__(self, initalState):
+        # The parameter keeps its original spelling so the signature is
+        # unchanged; the body used to read an undefined `initialState`.
+        self.history = [('init', initalState)]
+
+    def current(self):
+        """Return the most recently entered state object."""
+        return self.history[-1][1]
+
+    def transition(self, transitionName, params):
+        """Apply *transitionName* to the current state and record the new state.
+
+        Raises Exception when the current state does not list the transition,
+        and RuntimeError when the state returns no follow-up state.
+        """
+        # TODO: update Store & Interface
+        if transitionName not in self.current().availableTransitions:
+            raise Exception("Invalid transition")
+
+        newState = self.current().transition(transitionName, params)
+        if not newState:
+            # RuntimeException is not a Python builtin; RuntimeError is
+            raise RuntimeError(f"Invalid transition {transitionName} for {self.current()}")
+
+        self.history.append((transitionName, newState))
+
+    def getStateForHit(self, hit_id, stateCls = None):
+        """Return the latest ``(transitionName, state)`` entry recorded for *hit_id*.
+
+        When *stateCls* is given only states of that class are considered.
+        Returns None when no entry matches.
+        """
+        states = [s for s in self.history if s[1].hit_id == hit_id and (stateCls is None or isinstance(s[1], stateCls))]
+        # was `len(states < 1)`, which compares a list with an int and raises TypeError
+        if not states:
+            return None
+        return states[-1]
+
+class HITCreated(State):
+    """State whose only available transition is 'accept', leading to HITAssigned."""
+    availableTransitions = ['accept']
+
+    def __init__(self, hit_id):
+        super().__init__(hit_id)
+        # these assignments used to sit directly in the class body, where
+        # `self` is undefined and importing the module raised NameError
+        self.state = None
+        self.fee = None
+        self.hit_created = None
+        self.hit_opened = None
+        self.hit_submitted = None
+
+    def transition(self, transitionName, params = {}):
+        """Return HITAssigned for 'accept'; None for any other name."""
+        if transitionName == 'accept':
+            return HITAssigned(params['hit_id'], params['assignment_id'])
+
+class HITAssigned(State):
+    """State reached through 'accept'; can be rejected, abandoned or submitted."""
+    availableTransitions = ['reject', 'abandon', 'submit']
+
+    def __init__(self, hit_id, assignment_id = None):
+        # the base initialiser sets hit_id and time, which
+        # StateMachine.getStateForHit reads on every history entry
+        super().__init__(hit_id)
+        # optional: HITCreated passes an assignment id, HITAbandonedRejected does not
+        self.assignment_id = assignment_id
+        self.worker_id = None
+        self.ip = None
+        self.location = None
+        self.browser = None
+        self.os = None
+        self.resolution = None
+
+    def transition(self, transitionName, params = {}):
+        """Return HITAbandonedRejected for 'reject'/'abandon', HITSubmitted for 'submit'."""
+        if transitionName == 'reject' or transitionName == 'abandon':
+            return HITAbandonedRejected(params['hit_id'])
+        if transitionName == 'submit':
+            return HITSubmitted(params['hit_id'])
+
+class HITAbandonedRejected(State):
+    """State after 'reject' or 'abandon'; 'accept' leads back to HITAssigned."""
+    availableTransitions = ['accept']
+    def transition(self, transitionName, params = {}):
+        if transitionName == 'accept':
+            return HITAssigned(params['hit_id'])
+
+class HITSubmitted(State):
+    """State after 'submit'; 'scan' leads to Scanning."""
+    availableTransitions = ['scan']
+    def transition(self, transitionName, params = {}):
+        if transitionName == 'scan':
+            return Scanning(params['hit_id'])
+
+class Scanning(State):
+    """State after 'scan'; 'scan_complete' leads to ImageAvailable, 'scan_failed' raises."""
+    availableTransitions = ['scan_complete', 'scan_failed']
+    def transition(self, transitionName, params = {}):
+        if transitionName == 'scan_complete':
+            return ImageAvailable(params['hit_id'])
+        if transitionName == 'scan_failed':
+            raise Exception("Scan failed, unknown state")
+
+class ImageAvailable(State):
+    """State after 'scan_complete'; 'create_hit' leads to a new HITCreated."""
+    availableTransitions = ['create_hit']
+    def transition(self, transitionName, params = {}):
+        if transitionName == 'create_hit':
+            return HITCreated(params['hit_id'])
diff --git a/sorteerhoed/webserver.py b/sorteerhoed/webserver.py
index fe8f2af..ec5b5c2 100644
--- a/sorteerhoed/webserver.py
+++ b/sorteerhoed/webserver.py
@@ -34,13 +34,16 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
print(mime)
if mime == 'image/svg+xml':
self.set_header("Content-Type", "image/svg+xml")
-
-
+
+
class WebSocketHandler(tornado.websocket.WebSocketHandler):
+ """
+ Websocket from the workers
+ """
CORS_ORIGINS = ['localhost', '.mturk.com', 'here.rubenvandeven.com', 'guest.rubenvandeven.com']
connections = set()
-
+
def initialize(self, config, plotterQ: Queue, eventQ: Queue, store: HITStore):
self.config = config
self.plotterQ = plotterQ
@@ -60,35 +63,34 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler):
if hit_id != self.store.currentHit.id:
self.close()
return
-
+
self.hit = self.store.currentHit
-
+
+
+ # Keep the id as the string it arrives as: Assignment.assignment_id is a
+ # String column holding amazon's assignment id, so int() would raise on
+ # non-numeric ids and an int would never compare equal to the stored value.
+ self.assignment_id = self.get_query_argument('assignment_id')
+
self.timeout = datetime.datetime.now() + datetime.timedelta(seconds=self.store.getHitTimeout())
-
+
if self.hit.submit_hit_at:
raise Exception("Opening websocket for already submitted hit")
-
+
#logger.info(f"New client connected: {self.request.remote_ip} for {self.hit.id}/{self.hit.hit_id}")
- self.eventQ.put(Signal('server.open', dict(hit_id=self.hit.id)))
+ self.eventQ.put(Signal('server.open', dict(assignment_id=self.assignment_id)))
self.strokes = []
- # Gather some initial information:
- ua = self.request.headers.get('User-Agent', None)
- if ua:
- ua_info = httpagentparser.detect(ua)
- self.eventQ.put(Signal('hit.info', dict(hit_id=self.hit.id, os=ua_info['os']['name'], browser=ua_info['browser']['name'])))
-
-# self.write_message("hello!")
# the client sent the message
def on_message(self, message):
logger.debug(f"recieve: {message}")
-
+
+ lastAssignment = self.hit.getLastAssignment()
+ # getLastAssignment() returns None while the hit has no assignment yet.
+ # Compare as strings: the column stores amazon's assignment id as text.
+ if lastAssignment is None or str(self.assignment_id) != str(lastAssignment.assignment_id):
+     logger.critical(f"Skip message for non-last assignment {message}")
+     # actually skip, as the log line says, instead of processing the message anyway
+     return
+
if datetime.datetime.now() > self.timeout:
logger.critical("Close websocket after timeout (abandon?)")
self.close()
return
-
+
try:
msg = json.loads(message)
# TODO: sanitize input: min/max, limit strokes
@@ -97,12 +99,12 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler):
point = [float(msg['direction'][0]),float(msg['direction'][1]), bool(msg['mouse'])]
self.strokes.append(point)
self.plotterQ.put(point)
-
+
elif msg['action'] == 'up':
logger.info(f'up: {msg}')
point = [msg['direction'][0],msg['direction'][1], 1]
self.strokes.append(point)
-
+
elif msg['action'] == 'submit':
logger.info(f'submit: {msg}')
id = self.submit_strokes()
@@ -112,24 +114,24 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler):
#store svg:
d = html.escape(msg['d'])
svg = f"""
-
"""
-
+
with open(self.store.currentHit.getSvgImagePath(), 'w') as fp:
fp.write(svg)
-
+
self.write_message(json.dumps({
'action': 'submitted',
'msg': f"Submission ok, please copy this token to your HIT at Mechanical Turk: {self.hit.uuid}",
'code': str(self.hit.uuid)
}))
self.close()
-
+
elif msg['action'] == 'down':
# not used, implicit in move?
pass
@@ -152,13 +154,13 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler):
def on_close(self):
self.__class__.rmConnection(self)
logger.info(f"Client disconnected: {self.request.remote_ip}")
-
+
def submit_strokes(self):
if len(self.strokes) < 1:
return False
-
+
self.eventQ.put(Signal("server.submit", dict(hit_id = self.hit.id)))
-
+
if self.config['dummy_plotter']:
d = strokes2D(self.strokes)
svg = f"""
@@ -173,15 +175,15 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler):