import asyncio
import html
import json
import logging
import math
import os
import queue
from queue import Empty, Queue
from threading import Event, Thread
from urllib.parse import urlparse

import geoip2.database
import httpagentparser
import magic
import tornado.ioloop
import tornado.web
import tornado.websocket

from sorteerhoed import HITStore
from sorteerhoed.Signal import Signal


# Module logger, named "sorteerhoed.webserver": a child of the application's
# "sorteerhoed" logger, so records also propagate to that logger's handlers.
logger = logging.getLogger("sorteerhoed").getChild("webserver")

class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that adjusts response headers per file type.

    - ``.html`` files get a permissive ``Access-Control-Allow-Origin`` header.
    - ``.svg`` files are always served as ``image/svg+xml``.
    - ``.png`` files are content-sniffed: in testing, without a scanner, the
      images are saved as SVG under a ``.png`` name, so the sniffed SVG type
      is served instead of the extension-derived one.
    """

    def set_extra_headers(self, path):
        """For subclass to add extra headers to the response.

        :param path: the requested file path, relative to ``self.root``.
        """
        if path.endswith('.html'):
            self.set_header("Access-Control-Allow-Origin", "*")
        if path.endswith('.svg'):
            self.set_header("Content-Type", "image/svg+xml")
        if path.endswith('.png'):
            # in testing, without scanner, images are saved as svg
            mime = magic.from_file(os.path.join(self.root, path), mime=True)
            # Debug output goes through the module logger rather than stdout.
            logger.debug(f"Detected mime type {mime} for {path}")
            if mime == 'image/svg+xml':
                self.set_header("Content-Type", "image/svg+xml")
                
            

class WebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint used by the drawing page of the current HIT.

    The client sends JSON messages carrying an ``action`` key ('move', 'up',
    'down', 'info' or 'submit'). Points are recorded in ``self.strokes``;
    'move' points are additionally forwarded straight to the plotter queue.
    Lifecycle and metadata events are published on ``eventQ`` as ``Signal``s.
    """

    # Allowed origin hostnames. An entry starting with '.' matches any
    # subdomain of that domain; any other entry matches the host itself or
    # one of its subdomains (never a look-alike such as 'evillocalhost').
    CORS_ORIGINS = ['localhost', '.mturk.com', 'here.rubenvandeven.com']
    # Handlers that are currently open, shared by all instances.
    connections = set()

    def initialize(self, config, plotterQ: Queue, eventQ: Queue, store: HITStore):
        """Store shared application state passed in via the route kwargs.

        :param config: application config; ``config['dummy_plotter']`` is read
            on submit.
        :param plotterQ: receives ``[x, y, mouse]`` points for 'move' actions.
        :param eventQ: receives ``Signal`` events.
        :param store: HIT store; its ``currentHit`` is the only HIT accepted.
        """
        self.config = config
        self.plotterQ = plotterQ
        self.eventQ = eventQ
        self.store = store
        # Per-connection state is defined up front so that code running after
        # an early close in open() sees well-defined attributes.
        self.hit = None
        self.strokes = []

    def check_origin(self, origin):
        """Return True if the hostname of *origin* is allowed by CORS_ORIGINS."""
        hostname = urlparse(origin).hostname
        if not hostname:
            # e.g. the literal origin "null", which has no hostname
            return False
        for allowed in self.CORS_ORIGINS:
            if allowed.startswith('.'):
                if hostname.endswith(allowed):
                    return True
            elif hostname == allowed or hostname.endswith('.' + allowed):
                return True
        return False

    # the client connected
    def open(self, p = None):
        """Bind this connection to the current HIT given by the ``id`` query arg.

        The connection is closed when there is no current HIT or the id does
        not match it.

        :raises Exception: when the current HIT was already submitted.
        """
        self.__class__.connections.add(self)
        hit_id = int(self.get_query_argument('id'))
        current_hit = self.store.currentHit
        if current_hit is None or hit_id != current_hit.id:
            self.close()
            return

        self.hit = current_hit

        if self.hit.submit_hit_at:
            raise Exception("Opening websocket for already submitted hit")

        self.eventQ.put(Signal('server.open', dict(hit_id=self.hit.id)))
        self.strokes = []

        # Gather some initial information from the User-Agent. The parser may
        # omit the 'os'/'browser' entries for agents it does not recognise.
        ua = self.request.headers.get('User-Agent', None)
        if ua:
            ua_info = httpagentparser.detect(ua)
            self.eventQ.put(Signal('hit.info', dict(
                hit_id=self.hit.id,
                os=ua_info.get('os', {}).get('name'),
                browser=ua_info.get('browser', {}).get('name'),
            )))

    @staticmethod
    def _parse_direction(msg):
        """Return ``msg['direction'][0]`` and ``[1]`` as a pair of finite floats.

        Malformed input raises (KeyError/IndexError/TypeError/ValueError).
        ``json.loads`` accepts ``NaN``/``Infinity``, so non-finite values are
        rejected explicitly before they reach the plotter or the SVG.
        """
        x = float(msg['direction'][0])
        y = float(msg['direction'][1])
        if not (math.isfinite(x) and math.isfinite(y)):
            raise ValueError(f"Non-finite direction: {msg['direction']}")
        return x, y

    # the client sent the message
    def on_message(self, message):
        """Dispatch one JSON message from the drawing client.

        Any error while handling the message is logged and the message is
        otherwise ignored.
        """
        logger.debug(f"recieve: {message}")
        try:
            msg = json.loads(message)
            action = msg['action']
            # TODO: clamp coordinates to min/max and limit the number of strokes
            if action == 'move':
                x, y = self._parse_direction(msg)
                point = [x, y, bool(msg['mouse'])]
                self.strokes.append(point)
                self.plotterQ.put(point)

            elif action == 'up':
                logger.info(f'up: {msg}')
                # Recorded only; 'up' points are not sent to the plotter.
                x, y = self._parse_direction(msg)
                self.strokes.append([x, y, 1])

            elif action == 'submit':
                logger.info(f'submit: {msg}')
                submission_id = self.submit_strokes()
                if not submission_id:
                    self.write_message(json.dumps('error'))
                    return

                self.write_message(json.dumps({
                    'action': 'submitted',
                    'msg': f"Submission ok, please refer to your submission as: {self.hit.uuid}"
                    }))
                self.close()

            elif action == 'down':
                # not used, implicit in move?
                pass
            elif action == 'info':
                self.eventQ.put(Signal('hit.info', dict(
                    hit_id=self.hit.id,
                    resolution=msg['resolution'],
                    browser=msg['browser']
                )))
            else:
                logger.warning('Unknown request: {}'.format(message))

        except Exception as e:
            logger.exception(e)

    # client disconnected
    def on_close(self):
        """Forget this connection and log the disconnect."""
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    def submit_strokes(self):
        """Finalise the drawing for the current HIT.

        Emits a 'server.submit' signal. When ``config['dummy_plotter']`` is
        set, the strokes are also written as an SVG to the HIT's image path and
        a 'hit.scanned' signal is faked.

        :return: the HIT's uuid, or False when there is no HIT or no strokes.
        """
        if self.hit is None or not self.strokes:
            return False

        self.eventQ.put(Signal("server.submit", dict(hit_id = self.hit.id)))

        if self.config['dummy_plotter']:
            d = strokes2D(self.strokes)
            svg = f"""<?xml version="1.0" encoding="UTF-8" standalone="no"?>
            <svg viewBox="0 0 600 600"
           xmlns:dc="http://purl.org/dc/elements/1.1/"
           xmlns:cc="http://creativecommons.org/ns#"
           xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#"
           xmlns:svg="http://www.w3.org/2000/svg"
           xmlns="http://www.w3.org/2000/svg"
           version="1.1"
            >
            <path d="{d}" style="stroke:black;stroke-width:2;fill:none;" />
            </svg>
            """

            filename = self.hit.getImagePath()
            logger.info(f"Write to {filename}")
            with open(filename, 'w') as fp:
                fp.write(svg)

            # we fake a hit.scanned event
            self.eventQ.put(Signal('hit.scanned', {'hit_id':self.hit.id}))

        return self.hit.uuid

    @classmethod
    def rmConnection(cls, client):
        """Remove *client* from the open connections, if present."""
        cls.connections.discard(client)


class StatusWebSocketHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint for the status page.

    Connected clients receive ``{"property": ..., "value": ...}`` JSON
    messages through ``update_for_all``.
    """

    # Allowed origin hostnames: exact host or one of its subdomains.
    CORS_ORIGINS = ['localhost']
    # Handlers that are currently open, shared by all instances.
    connections = set()
    queue = queue.Queue()

    def initialize(self):
        pass

    def check_origin(self, origin):
        """Return True if the hostname of *origin* is allowed by CORS_ORIGINS."""
        hostname = urlparse(origin).hostname
        if not hostname:
            # e.g. the literal origin "null", which has no hostname
            return False
        for allowed in self.CORS_ORIGINS:
            if allowed.startswith('.'):
                if hostname.endswith(allowed):
                    return True
            elif hostname == allowed or hostname.endswith('.' + allowed):
                return True
        return False

    # the client connected
    def open(self, p = None):
        """Register this connection for status broadcasts."""
        self.__class__.connections.add(self)

    # client disconnected
    def on_close(self):
        """Unregister this connection and log the disconnect."""
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        """Remove *client* from the open connections, if present."""
        cls.connections.discard(client)

    @classmethod
    def update_for_all(cls, prop, value):
        """Send ``{'property': prop, 'value': value}`` to every open client.

        A client whose socket is already closed is dropped instead of aborting
        the broadcast for the remaining clients.
        """
        logger.debug(f"update for all {prop} {value}")
        if not cls.connections:
            return
        message = json.dumps({
            'property': prop,
            'value': value
            })
        # Iterate over a snapshot: dead clients are removed while looping.
        for connection in list(cls.connections):
            try:
                connection.write_message(message)
            except tornado.websocket.WebSocketClosedError:
                logger.warning(f"Dropping closed status client: {connection.request.remote_ip}")
                cls.rmConnection(connection)
        
def strokes2D(strokes):
    """Convert recorded strokes into an SVG path ``d`` attribute.

    Each stroke is indexable as ``[x, y, flag]``. The first point becomes an
    absolute ``M`` command. Every later point is emitted relative to the
    previous one: as an ``m`` (move) when the previous point's flag equals 1,
    otherwise as part of a relative ``l`` (line) run. The ``l`` command is
    only written when the previous command was not already ``l``.

    :param strokes: iterable of ``[x, y, flag]`` points.
    :return: the path data string ("" for no strokes).
    """
    parts = []  # collected with join instead of quadratic string +=
    previous = None
    cmd = ""
    for stroke in strokes:
        if previous is None:
            parts.append(f"M{stroke[0]},{stroke[1]} ")
            cmd = 'M'
        else:
            if previous[2] == 1:
                parts.append(" m")
                cmd = 'm'
            elif cmd != 'l':
                parts.append(' l ')
                cmd = 'l'

            dx = stroke[0] - previous[0]
            dy = stroke[1] - previous[1]
            parts.append(f"{dx},{dy} ")
        previous = stroke
    return "".join(parts)

class DrawPageHandler(tornado.web.RequestHandler):
    """Serve the drawing page (``index.html``) for the current HIT.

    The template's ``{IMAGE_URL}``, ``{WIDTH}``, ``{HEIGHT}``,
    ``{DRAW_WIDTH}``, ``{DRAW_HEIGHT}``, ``{TOP_PADDING}`` and
    ``{LEFT_PADDING}`` placeholders are filled in. The requester's IP and
    GeoIP country are then published on ``eventQ`` as 'hit.info' signals.
    """

    def initialize(self, store: HITStore,  eventQ: Queue, path: str, width: int, height: int, draw_width: int, draw_height: int, top_padding: int, left_padding: int, geoip_reader: geoip2.database.Reader):
        self.store = store
        self.path = path
        self.width = width
        self.height = height
        self.draw_width = draw_width
        self.draw_height = draw_height
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.eventQ = eventQ
        self.geoip_reader = geoip_reader

    def _client_ip(self):
        """Return the requester's IP, preferring the X-Forwarded-For header.

        The header may hold a proxy chain ("client, proxy1, proxy2"); the
        left-most entry is the originating client. NOTE: the header is
        client-supplied unless a trusted proxy overwrites it.
        """
        forwarded = self.request.headers.get('X-Forwarded-For')
        if forwarded:
            first = forwarded.split(',')[0].strip()
            if first:
                return first
        return self.request.remote_ip

    def get(self):
        """Render the drawing page for ``?id=<current HIT id>``."""
        try:
            hit_id = int(self.get_query_argument('id'))
            hit = self.store.currentHit
            if hit_id != hit.id:
                self.write("Invalid HIT")
                return
        except Exception:
            # missing/non-integer id, or no current HIT
            self.write("HIT not found")
            return

        if hit.submit_page_at:
            self.write("HIT already submitted")
            return

        previous_hit = self.store.getLastSubmittedHit()
        if not previous_hit:
            # start with basic svg
            logger.warning("No previous HIT, start from basic svg")
            image = "/basic.svg"
        else:
            image = previous_hit.getImageUrl()
            logger.info(f"Image url: {image}")

        self.set_header("Access-Control-Allow-Origin", "*")
        with open(os.path.join(self.path, 'index.html'), 'r') as fp:
            contents = fp.read()
        contents = contents.replace("{IMAGE_URL}", image)\
                            .replace("{WIDTH}", str(self.width))\
                            .replace("{HEIGHT}", str(self.height))\
                            .replace("{DRAW_WIDTH}", str(self.draw_width))\
                            .replace("{DRAW_HEIGHT}", str(self.draw_height))\
                            .replace("{TOP_PADDING}", str(self.top_padding))\
                            .replace("{LEFT_PADDING}", str(self.left_padding))
        self.write(contents)

        ip = self._client_ip()
        logger.info(f"Request from {ip}")
        self.eventQ.put(Signal('hit.info', dict(hit_id=hit.id, ip=ip)))

        try:
            geoip = self.geoip_reader.country(ip)
            logger.info(f"Geo {geoip}")
            self.eventQ.put(Signal('hit.info', dict(hit_id=hit.id, location=geoip.country.name)))
        except Exception as e:
            logger.exception(e)
            logger.info("No geo IP possible")
            self.eventQ.put(Signal('hit.info', dict(hit_id=hit.id, location='Unknown')))

class BackendHandler(tornado.web.RequestHandler):
    """Render an overview table of the 100 HITs from ``store.getHITs(100)``.

    The rows are substituted for ``{{TBODY}}`` in ``backend.html``.
    """

    def initialize(self, store: HITStore, path: str):
        self.store = store
        self.path = path

    @staticmethod
    def _format_duration(hit):
        """Return accept-to-submit time as e.g. '3m07s' / '07s', or '-' if unknown."""
        if not (hit.submit_hit_at and hit.accept_time):
            return "-"
        seconds = (hit.submit_hit_at - hit.accept_time).total_seconds()
        duration_m = int(seconds/60)
        duration_s = max(int(seconds%60), 0)
        return (f"{duration_m}m" if duration_m else "") + f"{duration_s:02d}s"

    @staticmethod
    def _render_row(hit):
        """Return one ``<tr>`` for *hit*, with all interpolated values HTML-escaped.

        Escaping matters because e.g. the IP can originate from a
        client-supplied X-Forwarded-For header.
        """
        # '.2f' gives two decimals; a bare '.2' means two significant digits
        # (10.0 would render as '1e+01').
        fee = f"${float(hit.fee):.2f}" if hit.fee else "-"
        return f"""
                <tr><td></td><td>{html.escape(str(hit.worker_id))}</td>
                <td>{html.escape(str(hit.turk_ip))}</td>
                <td>{html.escape(str(hit.turk_country))}</td>
                <td>{html.escape(fee)}</td>
                <td>{html.escape(str(hit.accept_time))}</td>
                <td>{html.escape(BackendHandler._format_duration(hit))}</td><td></td></tr>
                """

    def get(self):
        """Write the backend overview page."""
        rows = [self._render_row(hit) for hit in self.store.getHITs(100)]

        with open(os.path.join(self.path, 'backend.html'), 'r') as fp:
            contents = fp.read()
        contents = contents.replace("{{TBODY}}", "".join(rows))
        self.write(contents)

class StatusPage:
    """
    Properties for on the status page, which are sent over websockets the
    moment they are altered.

    Assigning an attribute logs the change and, when ``Server.loop`` is set,
    schedules ``StatusWebSocketHandler.update_for_all(name, value)`` on that
    loop via ``call_soon_threadsafe``. Assigning an unchanged value is ignored.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Set every status property back to None (each change is broadcast)."""
        logger.info("Resetting status")
        self.hit_id = None
        self.worker_id = None
        self.ip = None
        self.location = None
        self.browser = None
        self.os = None
        self.resolution = None
        self.state = None
        self.fee = None
        self.hit_created = None
        self.hit_opened = None
        self.hit_submitted = None

    def __setattr__(self, name, value):
        """Store *value* and push it to status clients if it changed."""
        if name in self.__dict__ and self.__dict__[name] == value:
            logger.debug(f"Ignore setting status of {name}: it already is set to {value}")
            return

        self.__dict__[name] = value
        logger.info(f"Update status: {name}: {value}")
        if Server.loop:
            # May be called from any thread; hand the broadcast to the loop.
            Server.loop.asyncio_loop.call_soon_threadsafe(StatusWebSocketHandler.update_for_all, name, value)
        else:
            logger.warning("Status: no server loop to call update command")

    def set(self, name, value):
        """Set status property *name* to *value* (same as attribute assignment)."""
        return setattr(self, name, value)

class Server:
    """
    Server for HIT -> plotter events
    As well as for the Status interface

    TODO: change to have the HIT_id as param to the page. Load hit from storage with previous image
    """
    # IOLoop of the running server, set by start(); StatusPage reads it to
    # schedule status broadcasts. None until start() has created it.
    loop = None

    def __init__(self, config, eventQ: Queue, runningEvent: Event, plotterQ: Queue, store: HITStore):
        """
        :param config: application config (reads 'server', 'scanner' and
            'dummy_plotter' entries).
        :param eventQ: queue receiving ``Signal`` events from the handlers.
        :param runningEvent: the loop only starts while set; cleared on exit.
        :param plotterQ: queue the websocket feeds points to directly.
        :param store: HIT store shared with the handlers.
        """
        self.isRunning = runningEvent
        self.eventQ = eventQ
        self.config = config
        self.logger = logger

        self.plotterQ = plotterQ # communicate directly to plotter (skip main thread)

        # Static files and page templates, relative to the working directory.
        self.web_root = 'www'

        self.server_loop = None
        self.store = store
        self.statusPage = StatusPage()

    def start(self):
        """Run the web application until stop() is called (blocks the caller).

        A fresh asyncio event loop is installed for the calling thread, all
        routes are registered and the app listens on
        ``config['server']['port']``. The loop only runs if ``runningEvent``
        is set. On exit, the listening sockets and the GeoIP database are
        closed and ``runningEvent`` is cleared.

        :raises Exception: if 'GeoLite2-Country.mmdb' is not in the working
            directory.
        """
        if not os.path.exists('GeoLite2-Country.mmdb'):
            raise Exception("Please download the GeoLite2 Country database and place the 'GeoLite2-Country.mmdb' file in the project root.")

        self.geoip_reader = geoip2.database.Reader('GeoLite2-Country.mmdb')
        http_server = None

        try:
            asyncio.set_event_loop(asyncio.new_event_loop())
            application = tornado.web.Application([
                (r"/ws(.*)", WebSocketHandler, {
                    'config': self.config,
                    'plotterQ': self.plotterQ,
                    'eventQ': self.eventQ,
                    'store': self.store,
                }),
                (r"/status/ws", StatusWebSocketHandler),
                (r"/draw", DrawPageHandler,
                    dict(
                        store = self.store,
                        eventQ = self.eventQ,
                        path=self.web_root,
                        width=self.config['scanner']['width'],
                        height=self.config['scanner']['height'],
                        draw_width=self.config['scanner']['draw_width'],
                        draw_height=self.config['scanner']['draw_height'],
                        top_padding=self.config['scanner']['top_padding'],
                        left_padding=self.config['scanner']['left_padding'],
                        geoip_reader= self.geoip_reader
                    )),
                (r"/backend", BackendHandler,
                    dict(
                        store = self.store,
                        path=self.web_root,
                    )),
                (r"/(.*)", StaticFileWithHeaderHandler,
                    {"path": self.web_root}),
            ], debug=True, autoreload=False)
            # listen() returns the HTTPServer; keep it so its sockets can be
            # released on shutdown.
            http_server = application.listen(self.config['server']['port'])
            self.server_loop = tornado.ioloop.IOLoop.current()
            Server.loop = self.server_loop
            if self.isRunning.is_set():
                self.server_loop.start()
        finally:
            self.logger.info("Stopping webserver")
            try:
                if http_server is not None:
                    http_server.stop()
                self.geoip_reader.close()
            finally:
                # Always signal that the server is no longer running.
                self.isRunning.clear()

    def stop(self):
        """Request the server loop to stop; safe to call from another thread."""
        if self.server_loop:
            self.logger.debug("Got call to stop")
            self.server_loop.asyncio_loop.call_soon_threadsafe(self._stop)

    def _stop(self):
        """Stop the IOLoop (runs on the loop's own thread)."""
        self.server_loop.stop()