guest_worker/sorteerhoed/HITStore.py

345 lines
12 KiB
Python

import logging
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, DateTime, Float
from sqlalchemy.orm import relationship
from sqlalchemy.sql.schema import ForeignKey, Sequence
from sqlalchemy.engine import create_engine
from sqlalchemy.orm.session import sessionmaker, object_session
import datetime
from contextlib import contextmanager
import uuid
import os
import country_converter
from svgpathtools import svg2paths
# Shared declarative base for the ORM models defined in this module.
Base = declarative_base()
# Single converter instance, reused for turk_country -> ISO2 lookups.
cc = country_converter.CountryConverter()

# Application-wide logger, and this module's child of it.
mainLogger = logging.getLogger("sorteerhoed")
logger = logging.getLogger("sorteerhoed.store")  # same object as mainLogger.getChild("store")
"""
HIT lifetime:
    created
    accepted
    (returned!)
    working
    awaiting Amazon confirmation (submitted on page)
    submitted

Actions:
    creating HIT (creating a HIT with the scanned image)
    scanning
"""
class HIT(Base):
    """Local record of one Mechanical Turk HIT and the assignments made for it.

    The human-readable state is derived from the timestamp columns and the
    last assignment; see getStatus().
    """
    __tablename__ = 'hits'

    id = Column(Integer, Sequence('hit_id'), primary_key=True)  # our sequential hit id
    hit_id = Column(String(255))  # amazon's hit id
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # `default` only fires on INSERT; `onupdate` refreshes the value whenever the
    # ORM flushes an UPDATE of this row, so updated_at actually tracks changes.
    updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
    scanned_at = Column(DateTime, default=None)
    plotted_at = Column(DateTime, default=None)
    deleted_at = Column(DateTime, default=None)  # soft-delete marker, set by delete()
    assignments = relationship("Assignment", back_populates="hit", order_by="Assignment.created_at")
    fee = Column(Float(precision=2), default=None)

    # previous hit so we can load the correct image
    # previous_hit_id = Column(Integer, ForeignKey('hits.id'), default=None)
    # previous_hit = relationship("HIT")

    def getImagePath(self):
        """Return the filesystem path of this HIT's scanned frame (relative path)."""
        return os.path.join('scanimation/interfaces/frames', f"{self.id:06d}.jpg")

    def getImageUrl(self):
        """Return the URL path of this HIT's scanned frame."""
        # URLs always use '/', so don't use os.path.join here (it would insert
        # a backslash on Windows); mirrors getSvgImageUrl().
        return f"/frames/{self.id:06d}.jpg"

    def getSvgImageUrl(self):
        """Return the URL path of this HIT's SVG file."""
        return f"/scans/{self.id:06d}.svg"

    def getSvgImagePath(self):
        """Return the filesystem path of this HIT's SVG file (relative path)."""
        # Not derived from getSvgImageUrl(): joining onto its leading '/'
        # would produce an absolute path.
        return f'www/scans/{self.id:06d}.svg'

    def getLastAssignment(self):
        """Return the most recently created assignment, or None if there are none."""
        if not self.assignments:
            return None
        # `assignments` is ordered by Assignment.created_at (see relationship above).
        return self.assignments[-1]

    def getAssignmentById(self, assignmentId):
        """Return the assignment whose Amazon assignment_id equals assignmentId, or None."""
        return next((a for a in self.assignments if a.assignment_id == assignmentId), None)

    def getStatus(self):
        """Return a state label for this HIT.

        Checks in priority order: "deleted", "creating" (no Amazon hit_id yet),
        "awaiting worker" (no assignment), "scanned"; otherwise defers to the
        last assignment's getStatus().
        """
        assignment = self.getLastAssignment()
        if self.deleted_at:
            return "deleted"
        if not self.hit_id:
            return "creating"
        if not assignment:
            return "awaiting worker"
        if self.scanned_at:
            return "scanned"
        return assignment.getStatus()

    def toDict(self) -> dict:
        """Serialise this HIT to a plain dict.

        Contains every table column plus:
        - 'assignment': the last assignment's toDict(), or None
        - 'state': getStatus()
        - 'scan_image': image URL once scanned, else None
        - 'svg_image': SVG URL once the last assignment is submitted, else None
        - 'preceding_assignments': toShortDict() of getBasedOnAssignments(),
          followed by a fixed entry for 'Ruben van de Ven & Merijn van Moll'
        - 'path_length': rounded length of the first path in the SVG file, or
          None when there is no SVG or it cannot be parsed
        """
        values = {c.name: getattr(self, c.name) for c in self.__table__.columns}
        assignment = self.getLastAssignment()
        values['assignment'] = assignment.toDict() if assignment else None
        values['state'] = self.getStatus()
        values['scan_image'] = self.getImageUrl() if self.scanned_at else None
        values['svg_image'] = self.getSvgImageUrl() if self.isSubmitted() else None
        values['preceding_assignments'] = [a.toShortDict() for a in self.getBasedOnAssignments()]
        values['preceding_assignments'].append({
            'worker_id': 'Ruben van de Ven & Merijn van Moll',
            'turk_country': 'the Netherlands',
            'turk_country_code': 'NL'
        })

        svgPath = self.getSvgImagePath()
        values['path_length'] = None
        if values['svg_image'] and os.path.exists(svgPath):
            try:
                paths, _ = svg2paths(svgPath)
                values['path_length'] = round(paths[0].length())
            except Exception:
                # Best effort: an empty or unreadable SVG must not break
                # serialisation, but don't hide it (or swallow Ctrl-C) either.
                logger.warning(f"Could not compute path length of {svgPath}", exc_info=True)
        return values

    def delete(self):
        """Soft-delete this HIT by stamping deleted_at (caller must persist it)."""
        self.deleted_at = datetime.datetime.utcnow()

    def isSubmitted(self) -> bool:
        """True when the last assignment has submit_page_at set."""
        a = self.getLastAssignment()
        if not a:
            return False
        return bool(a.submit_page_at)

    def isConfirmed(self) -> bool:
        """True when the last assignment has confirmed_at set."""
        a = self.getLastAssignment()
        if not a:
            return False
        return bool(a.confirmed_at)

    def getBasedOnAssignments(self):
        """
        Get preceding assignments, one per worker, excluding the one who did this HIT.

        Returns a (lazy) query of assignments submitted before this HIT was
        created, grouped by worker_id. Requires this HIT to be attached to a
        session (object_session() returns None otherwise).

        NOTE(review): in SQLite, non-aggregated columns of a GROUP BY query come
        from an arbitrary row of each group, so *which* assignment is returned
        per worker is not guaranteed to be the latest — confirm intent.
        """
        assignment = self.getLastAssignment()
        session = object_session(self)
        q = session.query(Assignment).\
            filter(Assignment.submit_page_at < self.created_at).\
            group_by(Assignment.worker_id).\
            order_by(Assignment.created_at.desc())
        if assignment and assignment.worker_id:
            q = q.filter(Assignment.worker_id != assignment.worker_id)
        return q
class Assignment(Base):
    """One worker's attempt at a HIT; its state is derived from the timestamps."""
    __tablename__ = 'assignments'

    id = Column(Integer, Sequence('assignment_id'), primary_key=True)  # our sequential assignment id
    hit_id = Column(Integer, ForeignKey('hits.id'))  # our sequential hit id
    hit = relationship("HIT", back_populates="assignments")
    # The attribute name shadows the `uuid` module only inside the class body;
    # the lambda runs later and resolves `uuid` to the module-level import.
    uuid = Column(String(32), default=lambda : uuid.uuid4().hex)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # `onupdate` keeps this current on every ORM UPDATE (default only covers INSERT).
    updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow)
    # amazon's assignment id (previously declared twice; only this one was effective)
    assignment_id = Column(String(255), default = None)
    worker_id = Column(String(255), default = None)
    accept_at = Column(DateTime, default=None) # accept time according to SQS
    # open_page_at = Column(DateTime, default=None)
    submit_page_at = Column(DateTime, default=None) # Submit the page
    confirmed_at = Column(DateTime, default=None) # validate with UUID when getting Message from Amazon
    abandoned_at = Column(DateTime, default=None)
    rejected_at = Column(DateTime, default=None)
    answer = Column(String(255), default=None)
    turk_ip = Column(String(255), default=None)
    turk_country = Column(String(255), default=None)
    turk_os = Column(String(255), default=None)
    turk_browser = Column(String(255), default=None)

    def getStatus(self):
        """Return "rejected", "abandoned", "working", "submitted" or "confirmed".

        Terminal states (rejected, abandoned) take precedence over progress.
        """
        if self.rejected_at:
            return "rejected"
        if self.abandoned_at:
            return "abandoned"
        if not self.submit_page_at:
            return "working"
        if not self.confirmed_at:
            return "submitted"
        return "confirmed"

    def _countryCode(self):
        """Return country_converter's ISO2 conversion of turk_country, or None when unset."""
        if not self.turk_country:
            return None
        return cc.convert([self.turk_country], to='ISO2')

    def toDict(self) -> dict:
        """Serialise all table columns plus a derived 'turk_country_code'."""
        values = {c.name: getattr(self, c.name) for c in self.__table__.columns}
        values['turk_country_code'] = self._countryCode()
        return values

    def toShortDict(self) -> dict:
        """Serialise only worker_id, turk_country and the derived 'turk_country_code'."""
        return {
            'worker_id': self.worker_id,
            'turk_country': self.turk_country,
            'turk_country_code': self._countryCode(),
        }
class Store:
    """SQLite-backed persistence for HITs and Assignments.

    Holds one long-lived session. Registered update hooks are notified with the
    affected HIT after each create/save.
    """
    def __init__(self, db_filename, logLevel=0):
        path = os.path.abspath(db_filename)
        # Note: with the default logLevel (0) this also turns on SQL statement logging.
        if logLevel <= logging.DEBUG:
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
        # Only consumed by the (disabled) initial-content code below.
        needsInitialization = not os.path.exists(path)
        # check_same_thread=False lets the sqlite connection be used from threads
        # other than the one that created it.
        self.engine = create_engine('sqlite:///'+path, echo=False, connect_args={'check_same_thread': False})
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()
        self.currentHit = None # mirrors CentralManagement, stored here so we can quickly access it from webserver classes
        self.updateHooks = []
        # if needsInitialization:
        #     self.insertInitialContent()
    #
    # def insertInitialContent(self):
    #     hit = self.createHIT()
    #     assignment = self.newAssignment(hit, 'initial')

    def registerUpdateHook(self, hook):
        """Register a callable, or an object with an update(hit) method; duplicates are ignored."""
        if hook not in self.updateHooks:
            logger.info(f"Register update hook: {hook}")
            self.updateHooks.append(hook)

    def triggerUpdateHooks(self, hit = None):
        """Notify every registered hook about `hit`."""
        for hook in self.updateHooks:
            if callable(hook): # it's a method
                hook(hit)
            else: # assume it's an object
                hook.update(hit)

    @contextmanager
    def getSession(self):
        """Provide a transactional scope around a series of operations.

        Commits on normal exit; on any exception (including KeyboardInterrupt)
        rolls back and re-raises.
        """
        try:
            yield self.session
            self.session.commit()
        except BaseException:
            self.session.rollback()
            raise

    def getHits(self):
        """Query of all HITs (including deleted ones), newest first."""
        return self.session.query(HIT).order_by(HIT.created_at.desc())

    def getHitById(self, hitId):
        """Return the HIT with our sequential id; .one() raises if not exactly one matches."""
        return self.session.query(HIT).\
            filter(HIT.id==hitId).one()

    def getHitByRemoteId(self, amazonHitId):
        """Return the HIT with the given Amazon hit id; .one() raises if not exactly one matches."""
        return self.session.query(HIT).\
            filter(HIT.hit_id==amazonHitId).one()

    def getLastSubmittedHit(self):
        """Return the newest HIT that has an assignment with submit_page_at set, or None."""
        return self.session.query(HIT).\
            join(Assignment).\
            filter(Assignment.submit_page_at!=None).\
            order_by(HIT.created_at.desc()).first()

    def getNewestHits(self, n = 2) -> list:
        """Return the `n` newest non-deleted HITs in chronological order (all when n is None)."""
        q = self.session.query(HIT).\
            filter(HIT.deleted_at==None).\
            order_by(HIT.created_at.desc())
        if n is not None:
            q = q.limit(n)
        hits = list(q)
        # select DESC, because we want latest, then reverse list to get in right order
        hits.reverse()
        return hits

    def createHIT(self) -> HIT:
        """Insert a new, empty HIT, commit it and notify the update hooks."""
        with self.getSession() as s:
            hit = HIT()
            s.add(hit)
            s.flush()
            s.refresh(hit)
        # Outside the `with`: hooks run only after the transaction has committed.
        logger.info(f"Created HIT {hit.id}")
        self.triggerUpdateHooks(hit)
        return hit

    def newAssignment(self, hit: HIT, assignmentId) -> Assignment:
        """Attach a new Assignment with Amazon id `assignmentId` to `hit`, commit, notify hooks."""
        # TODO: reset() central management if has pending lastAssignment()
        with self.getSession() as s:
            assignment = Assignment()
            assignment.assignment_id = assignmentId
            hit.assignments.append(assignment)
            s.add(assignment)
            s.flush()
            s.refresh(hit)
        # Outside the `with`: hooks run only after the transaction has committed.
        logger.info(f"Created Assignment {assignment.id}")
        self.triggerUpdateHooks(hit)
        return assignment

    def saveHIT(self, hit):
        """Commit pending changes (e.g. to `hit`) and notify the update hooks."""
        with self.getSession() as s:
            logger.info(f"Updating hit! {hit.id}")
        self.triggerUpdateHooks(hit)

    def saveAssignment(self, assignment):
        """Commit pending changes (e.g. to `assignment`) and notify hooks with its HIT."""
        with self.getSession() as s:
            logger.info(f"Updating assignment! {assignment.id}")
        self.triggerUpdateHooks(assignment.hit)

    def getAvgDurationOfPreviousNHits(self, n) -> int:
        """Average seconds from creation to page submit over the `n` latest submitted assignments.

        Falls back to 150 seconds (2.5 minutes) when no assignment qualifies.
        """
        latestAssignments = self.session.query(Assignment).\
            filter(Assignment.created_at!=None).\
            filter(Assignment.submit_page_at!=None).\
            order_by(Assignment.created_at.desc()).limit(n)
        durations = [
            (a.submit_page_at - a.created_at).total_seconds()
            for a in latestAssignments
        ]
        if not durations:
            return int(2.5*60) # default to 2.5 minutes
        return int(sum(durations) / len(durations))

    def getEstimatedHitDuration(self):
        """Estimated HIT duration in seconds, based on the last 5 submitted assignments."""
        return self.getAvgDurationOfPreviousNHits(5)

    def getHitTimeout(self):
        """HIT timeout in seconds (currently a fixed value)."""
        return 160 # max(160, self.getAvgDurationOfPreviousNHits(5)*2)

    def getHITs(self, n = 100):
        """Query of up to `n` submitted HITs, newest first.

        A HIT counts as submitted when any of its assignments has
        submit_page_at set (same criterion as getLastSubmittedHit()).
        """
        # NOTE(review): this used to filter/order on HIT.submit_hit_at, which the
        # HIT model does not define, so every call raised AttributeError. The
        # EXISTS filter (relationship.any) avoids duplicate HIT rows that a plain
        # join would produce for HITs with several submitted assignments.
        return self.session.query(HIT).\
            filter(HIT.assignments.any(Assignment.submit_page_at != None)).\
            order_by(HIT.created_at.desc()).limit(n)