trap/trap/socket_forwarder.py

193 lines
6.4 KiB
Python

from argparse import Namespace
import asyncio
import dataclasses
import errno
import json
import logging
from multiprocessing import Event
import subprocess
from typing import Set, Union, Dict, Any
from typing_extensions import Self
from urllib.error import HTTPError
import tornado.ioloop
import tornado.web
import tornado.websocket
import zmq
import zmq.asyncio
from trap.frame_emitter import Frame
# Module-level logger used by the websocket handlers and the router below.
logger = logging.getLogger("trap.forwarder")
class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
    """Websocket endpoint that relays each incoming message onto a ZMQ socket."""

    def initialize(self, zmq_socket: zmq.asyncio.Socket):
        """Keep the ZMQ socket that received messages are forwarded to."""
        self.zmq_socket = zmq_socket

    def open(self, p=None):
        """Log the remote address of a newly connected client."""
        remote_ip = self.request.remote_ip
        logger.info(f"connected {remote_ip}")

    async def on_message(self, message):
        """Send the received websocket payload as a string over the ZMQ socket.

        Any exception raised while sending is logged and not propagated.
        """
        logger.debug("receive msg")
        try:
            await self.zmq_socket.send_string(message)
        except Exception as err:
            logger.exception(err)

    def on_close(self):
        """Log the remote address of the client that disconnected."""
        remote_ip = self.request.remote_ip
        logger.info(f"Client disconnected: {remote_ip}")
class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
    """Send-only websocket endpoint that tracks its open clients for broadcasts."""

    # Handler instances that are currently open; shared on the class so that
    # write_to_clients() can reach all of them.
    connections: Set[Self] = set()

    def initialize(self, config):
        """Keep the configuration object passed in with the route."""
        self.config = config

    def open(self, p=None):
        """Log the new client and register it in the shared connection set."""
        logger.info(f"Prediction WS connected {self.request.remote_ip}")
        type(self).connections.add(self)

    def on_message(self, message):
        """Warn about inbound messages; this handler only sends."""
        logger.warning(f"Receiving message on send-only ws handler: {message}")

    def on_close(self):
        """Unregister the client, then log that it disconnected."""
        type(self).rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def hasConnection(cls, client):
        """Return whether ``client`` is currently registered."""
        return client in cls.connections

    @classmethod
    def rmConnection(cls, client):
        """Drop ``client`` from the registered set; unknown clients are ignored."""
        cls.connections.discard(client)

    @classmethod
    def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
        """Write ``msg`` to every registered client.

        A ``None`` message is logged as critical and not sent. Clients whose
        websocket turns out to be closed are unregistered afterwards.
        """
        if msg is None:
            logger.critical("Tried to send 'none'")
            return

        # Closed clients are collected first and removed after the loop:
        # removing from the set while iterating over it raises an exception.
        stale = []
        for connection in cls.connections:
            try:
                connection.write_message(msg)
            except tornado.websocket.WebSocketClosedError:
                logger.warning("Not properly closed websocket connection")
                stale.append(connection)

        for connection in stale:
            cls.rmConnection(connection)
class DemoHandler(tornado.web.RequestHandler):
    """Request handler that renders the ``index.html`` template."""

    def initialize(self, config: Namespace):
        """Keep the configuration namespace passed in with the route."""
        self.config = config

    def get(self):
        """Render ``index.html``, handing the websocket port to the template."""
        port = self.config.ws_port
        self.render("index.html", ws_port=port)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that sets extra headers for some file types."""

    def set_extra_headers(self, path):
        """Set additional response headers based on the suffix of ``path``.

        ``.html`` paths get a wildcard ``Access-Control-Allow-Origin`` header;
        ``.svg`` paths get an ``image/svg+xml`` content type.
        """
        is_html = path.endswith(".html")
        is_svg = path.endswith(".svg")

        if is_html:
            self.set_header("Access-Control-Allow-Origin", "*")
        if is_svg:
            self.set_header("Content-Type", "image/svg+xml")
class WsRouter:
    """Route messages between browser websockets and ZMQ sockets.

    Messages arriving on ``/ws/trajectory`` are published on a ZMQ PUB socket;
    frames received on a ZMQ SUB socket are serialised to JSON and written to
    all clients connected to ``/ws/prediction``.
    """

    def __init__(self, config: Namespace, is_running: Event):
        """Create the ZMQ sockets and the tornado application.

        Args:
            config: Options namespace. This class reads ``zmq_trajectory_addr``,
                ``zmq_prediction_addr``, ``bypass_prediction`` and ``ws_port``.
            is_running: Event checked by the forwarder loop; the loop ends
                once it is no longer set.
        """
        self.config = config
        self.is_running = is_running

        context = zmq.asyncio.Context()

        self.trajectory_socket = context.socket(zmq.PUB)
        logger.info(f'Publish trajectories on {config.zmq_trajectory_addr}')
        self.trajectory_socket.bind(config.zmq_trajectory_addr)

        self.prediction_socket = context.socket(zmq.SUB)
        # only keep latest frame. NB. make sure this comes BEFORE connect,
        # otherwise it's ignored!!
        self.prediction_socket.setsockopt(zmq.CONFLATE, 1)
        self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
        # With bypass_prediction set, subscribe to our own trajectory
        # publisher instead of the prediction address.
        self.prediction_socket.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)

        self.application = tornado.web.Application(
            [
                (
                    r"/ws/trajectory",
                    WebSocketTrajectoryHandler,
                    {
                        "zmq_socket": self.trajectory_socket
                    },
                ),
                (
                    r"/ws/prediction",
                    WebSocketPredictionHandler,
                    {
                        "config": config,
                    },
                ),
                (r"/", DemoHandler, {"config": config}),
                # (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}),
            ],
            template_path = 'trap/web/',
            compiled_template_cache=False)

    def start(self):
        """Listen on the websocket port and run the event loop until stopped.

        Blocks in ``run_forever()`` until ``prediction_forwarder`` stops the
        loop.

        Raises:
            OSError: If listening on ``config.ws_port`` fails. When the port is
                already in use, a critical message is logged first.
        """
        self.evt_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.evt_loop)

        logger.info(f"Listen on {self.config.ws_port}")
        try:
            self.application.listen(self.config.ws_port)
        except OSError as e:
            if e.errno == errno.EADDRINUSE:
                logger.critical("Address already in use by process")
                self._report_port_owner()
            raise

        # Hold a reference on the instance so the task cannot be garbage
        # collected while the loop is running.
        self._forwarder_task = self.evt_loop.create_task(self.prediction_forwarder())
        self.evt_loop.run_forever()

    def _report_port_owner(self):
        """Best-effort diagnostic: run ``lsof`` for the configured port.

        A failure to run ``lsof`` (e.g. the tool is not installed) is only
        logged, so it never replaces the bind error the caller re-raises.
        """
        try:
            subprocess.run(["lsof", "-i", f"tcp:{self.config.ws_port:d}"])
        except (OSError, ValueError) as err:
            logger.warning(f"Could not run lsof to identify the process: {err}")

    async def prediction_forwarder(self):
        """Forward frames from the prediction socket to websocket clients.

        Runs while ``is_running`` is set. Errors while handling a single
        frame are logged and the loop continues. When the loop ends for any
        reason, including an unexpected error, the event loop is stopped.
        """
        logger.info("Starting prediction forwarder")
        try:
            while self.is_running.is_set():
                # timeout so that if no events occur, loop can still stop on is_running
                has_event = await self.prediction_socket.poll(timeout=1000)
                if not has_event:
                    continue

                try:
                    # NOTE(security): recv_pyobj() unpickles whatever arrives
                    # on the socket; only connect to trusted publishers.
                    frame: Frame = await self.prediction_socket.recv_pyobj()
                    msg = json.dumps(frame.aslist())
                    logger.debug(f"Forward prediction message of {len(msg)} chars")
                    WebSocketPredictionHandler.write_to_clients(msg)
                except Exception as e:
                    logger.exception(e)
        except Exception as e:
            # Errors outside the per-frame handling (e.g. from poll) would
            # otherwise end this task silently and leave the loop running.
            logger.exception(e)
        finally:
            # die together:
            self.evt_loop.stop()
            logger.info('Stopping')
def run_ws_forwarder(config: Namespace, is_running: Event):
    """Build a ``WsRouter`` from the given arguments and start it."""
    WsRouter(config, is_running).start()