163 lines
5.1 KiB
Python
163 lines
5.1 KiB
Python
import logging
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy import Column, Integer, String, DateTime
|
|
from sqlalchemy.orm import relationship
|
|
from sqlalchemy.sql.schema import ForeignKey, Sequence
|
|
from sqlalchemy.engine import create_engine
|
|
from sqlalchemy.orm.session import sessionmaker
|
|
import datetime
|
|
from contextlib import contextmanager
|
|
import uuid
|
|
import os
|
|
import coloredlogs
|
|
import argparse
|
|
from sqlalchemy.sql.functions import func
|
|
|
|
# Application-wide logger; this module logs through a child of it.
mainLogger = logging.getLogger("sorteerhoed")
logger = logging.getLogger(f"{mainLogger.name}.store")

# Declarative base class shared by the ORM models defined in this module.
Base = declarative_base()
|
|
|
|
"""
|
|
HIT lifetime:
|
|
|
|
|
|
created
|
|
accepted
|
|
(returned!)
|
|
working
|
|
awaiting amazon confirmation (submitted on page)
|
|
submitted
|
|
|
|
Actions:
|
|
creating Hit (creating hit with scanned image)
|
|
|
|
Scanning
|
|
|
|
"""
|
|
|
|
class HIT(Base):
    """A single Mechanical Turk HIT plus the local bookkeeping around it.

    ``created_at`` and ``updated_at`` are filled from
    ``datetime.datetime.utcnow`` (naive UTC).  The remaining timestamp
    columns record lifecycle milestones and start out as ``None``; see
    :meth:`getStatus` for how they map to a human-readable status.
    """
    __tablename__ = 'hits'

    id = Column(Integer, Sequence('hit_id'), primary_key=True)  # our sequential hit id
    hit_id = Column(String(255))  # amazon's hit id
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # ``onupdate`` refreshes the timestamp whenever SQLAlchemy emits an UPDATE
    # for this row; with only ``default`` it stayed frozen at creation time.
    updated_at = Column(DateTime, default=datetime.datetime.utcnow,
                        onupdate=datetime.datetime.utcnow)
    # Inside the lambda, ``uuid`` is the module, not this column: names bound
    # in a class body are not visible to functions nested in it.
    uuid = Column(String(32), default=lambda: uuid.uuid4().hex)
    assignment_id = Column(String(255), default=None)
    worker_id = Column(String(255), default=None)
    accept_time = Column(DateTime, default=None)
    open_page_at = Column(DateTime, default=None)
    submit_page_at = Column(DateTime, default=None)
    submit_hit_at = Column(DateTime, default=None)
    answer = Column(String(255), default=None)
    turk_ip = Column(String(255), default=None)
    turk_country = Column(String(255), default=None)
    turk_screen_width = Column(Integer, default=None)
    turk_screen_height = Column(Integer, default=None)
    scanned_at = Column(DateTime, default=None)

    def getImagePath(self):
        """Return the filesystem path of this HIT's scan (under ``www``)."""
        return os.path.join('www', self.getImageUrl())

    def getImageUrl(self):
        """Return the URL path of this HIT's scan, keyed by the local id."""
        return f"scans/{self.id}.png"

    def getStatus(self):
        """Return a label for the furthest lifecycle milestone reached.

        Milestones are checked from most to least advanced, and the first
        truthy one wins; with none set the HIT is ``"created"``.
        """
        milestones = (
            ('scanned_at', "completed"),
            ('submit_hit_at', "submission confirmed"),
            ('submit_page_at', "submitted by worker"),
            ('open_page_at', "started working"),
            ('accept_time', "accepted by worker"),
        )
        for field, label in milestones:
            if getattr(self, field):
                return label
        return "created"
|
|
|
|
|
|
|
|
class Store:
    """Persistence layer for :class:`HIT` records in a SQLite database.

    All methods share one long-lived session, ``self.session``.
    NOTE(review): the engine is created with ``check_same_thread=False``,
    which suggests use from several threads.  A SQLAlchemy session is not
    thread-safe, so confirm that callers serialise access to the store.
    """

    def __init__(self, db_filename, logLevel=0):
        """Open the SQLite database at *db_filename*, creating tables if needed.

        :param db_filename: path to the SQLite file; made absolute before use.
        :param logLevel: a ``logging`` level.  At DEBUG or below, which
            includes the default 0 (NOTSET), SQLAlchemy's engine logger is
            raised to INFO so the emitted SQL is logged.
        """
        path = os.path.abspath(db_filename)
        if logLevel <= logging.DEBUG:
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

        # check_same_thread=False allows the sqlite3 connection to be used
        # from threads other than the one that opened it.
        self.engine = create_engine('sqlite:///' + path, echo=False,
                                    connect_args={'check_same_thread': False})
        # Create any missing tables for models registered on Base.
        Base.metadata.create_all(self.engine)
        self.Session = sessionmaker(bind=self.engine)
        self.session = self.Session()

    @contextmanager
    def getSession(self):
        """Provide a transactional scope around a series of operations.

        Yields the shared ``self.session``.  When the ``with`` body finishes
        normally, the session is committed.  On any exception, including a
        failing commit, it is rolled back and the exception is re-raised.
        """
        try:
            yield self.session
            self.session.commit()
        except BaseException:
            # BaseException (not just Exception) means that KeyboardInterrupt
            # and friends also leave the session clean; the error propagates.
            self.session.rollback()
            raise

    def getHits(self, session=None):
        """Return a query over all HITs, newest (by ``created_at``) first.

        :param session: session to run the query on; defaults to the store's
            shared session.
        """
        s = session if session is not None else self.session
        return s.query(HIT).order_by(HIT.created_at.desc())

    def getHitById(self, hitId):
        """Return the HIT with local primary key *hitId*.

        Raises SQLAlchemy's ``NoResultFound`` / ``MultipleResultsFound``
        (from ``Query.one``) when there is no row or more than one row.
        """
        return self.session.query(HIT).filter(HIT.id == hitId).one()

    def getHitByRemoteId(self, amazonHitId):
        """Return the HIT whose Amazon HIT id is *amazonHitId*.

        Raises SQLAlchemy's ``NoResultFound`` / ``MultipleResultsFound``
        (from ``Query.one``) when there is no row or more than one row.
        """
        return self.session.query(HIT).filter(HIT.hit_id == amazonHitId).one()

    def getLastSubmittedHit(self):
        """Return the HIT most recently submitted on the page, or ``None``."""
        return self.session.query(HIT).\
            filter(HIT.submit_page_at.isnot(None)).\
            order_by(HIT.submit_page_at.desc()).first()

    def createHIT(self):
        """Insert a new, default-initialised HIT, commit, and return it.

        The row is flushed and refreshed inside the transaction, so the
        returned object already carries its database-generated ``id``.
        """
        with self.getSession() as s:
            hit = HIT()
            s.add(hit)
            s.flush()
            s.refresh(hit)
            logger.info(f"Created HIT {hit.id}")
        return hit

    def saveHIT(self, hit):
        """Persist pending changes to *hit* by committing the shared session.

        *hit* is added to the session first.  This is a no-op if it is already
        attached, and it ensures that a detached or new instance is written
        instead of being silently ignored.
        """
        with self.getSession() as s:
            s.add(hit)
            logger.info(f"Updating hit! {hit.id}")

    def addHIT(self, hit: HIT):
        """Insert *hit*, refresh it so its generated ``id`` is set, and commit."""
        with self.getSession() as s:
            s.add(hit)
            s.flush()
            s.refresh(hit)
            logger.info(f"Added {hit.id}")

    def getAvgDurationOfPreviousNHits(self, n, default=int(2.5 * 60)) -> int:
        """Return the mean accept-to-confirmed-submit time of the last *n* HITs.

        Only HITs that have both ``accept_time`` and ``submit_hit_at`` set are
        considered, most recently confirmed first.

        :param n: maximum number of recent HITs to average over.
        :param default: value returned when no qualifying HIT exists
            (150 seconds unless overridden).
        :returns: the mean duration in whole seconds, truncated toward zero.
        """
        latest_hits = self.session.query(HIT).\
            filter(HIT.submit_hit_at.isnot(None)).\
            filter(HIT.accept_time.isnot(None)).\
            order_by(HIT.submit_hit_at.desc()).limit(n)

        durations = [
            (hit.submit_hit_at - hit.accept_time).total_seconds()
            for hit in latest_hits
        ]
        if not durations:
            return default
        return int(sum(durations) / len(durations))
|
|
|