Trajectron-plus-plus/trajpred/socket_forwarder.py
2023-10-11 16:35:15 +02:00

162 lines
No EOL
5.1 KiB
Python

from argparse import Namespace
import asyncio
import logging
from typing import Set, Union, Dict, Any
from typing_extensions import Self
from urllib.error import HTTPError
import tornado.ioloop
import tornado.web
import tornado.websocket
import zmq
import zmq.asyncio
# Module-level logger; the handlers and the router in this file all log through it.
logger = logging.getLogger("trajpred.forwarder")
class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
    """WebSocket endpoint that passes every incoming message on to a ZMQ socket."""

    def initialize(self, zmq_socket: zmq.asyncio.Socket):
        """Keep the ZMQ socket supplied in the route's handler arguments."""
        self.zmq_socket = zmq_socket

    def open(self, p=None):
        """Log the address of a newly connected client."""
        logger.info(f"connected {self.request.remote_ip}")

    def on_close(self):
        """Log the address of a client that disconnected."""
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    async def on_message(self, message):
        """Log the incoming message and send it as a string on the ZMQ socket.

        A failing send is logged with its traceback and is not re-raised.
        """
        logger.info(f"recieve: {message}")
        try:
            await self.zmq_socket.send_string(message)
        except Exception as exc:
            logger.exception(exc)
class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
    """Send-only WebSocket endpoint whose messages are broadcast to all clients.

    Open connections are tracked in the class-level ``connections`` set so that
    ``write_to_clients`` can reach every client without holding an instance.
    """

    # Handler instances that are currently open; shared through the class.
    connections: Set[Self] = set()

    def initialize(self, config):
        """Store the configuration supplied in the route's handler arguments."""
        self.config = config

    def on_message(self, message):
        """Warn about incoming data; this endpoint is only used for sending."""
        logger.warning(f"Receiving message on send-only ws handler: {message}")

    def open(self, p=None):
        """Log the new client and register it for broadcasts."""
        logger.info(f"Prediction WS connected {self.request.remote_ip}")
        self.__class__.connections.add(self)

    def on_close(self):
        """Unregister the client and log the disconnect."""
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        """Forget ``client``; a client that is not registered is ignored."""
        cls.connections.discard(client)

    @classmethod
    def hasConnection(cls, client):
        """Return whether ``client`` is currently registered."""
        return client in cls.connections

    @classmethod
    def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
        """Send ``msg`` to every registered client.

        ``None`` is rejected with a critical log entry and nothing is sent.
        Clients whose connection turns out to be closed are unregistered,
        whether the write fails immediately or only once the Future returned
        by ``write_message`` completes.

        Args:
            msg: Payload handed unchanged to each client's ``write_message``.
        """
        if msg is None:
            logger.critical("Tried to send 'none'")
            return

        # Iterate over a snapshot so clients can be unregistered inside the
        # loop without changing the set that is being iterated.
        for client in tuple(cls.connections):
            try:
                pending = client.write_message(msg)
            except tornado.websocket.WebSocketClosedError:
                logger.warning("Not properly closed websocket connection")
                cls.rmConnection(client)
                continue
            # The write may still fail after this call returned; collect the
            # result so that failure is handled instead of left unretrieved.
            pending.add_done_callback(
                lambda done, client=client: cls._finish_write(client, done)
            )

    @classmethod
    def _finish_write(cls, client, future):
        """Inspect a completed write and unregister ``client`` if it was closed."""
        if future.cancelled():
            return
        error = future.exception()
        if error is None:
            return
        if isinstance(error, tornado.websocket.WebSocketClosedError):
            logger.warning("Not properly closed websocket connection")
            cls.rmConnection(client)
        else:
            logger.error("Writing to websocket client failed", exc_info=error)
class DemoHandler(tornado.web.RequestHandler):
    """Request handler that renders the ``index.html`` template."""

    def initialize(self, config: Namespace):
        """Keep the configuration namespace for use while handling requests."""
        self.config = config

    def get(self):
        """Render ``index.html``, passing the configured ``ws_port`` to it."""
        port = self.config.ws_port
        self.render("index.html", ws_port=port)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    """Static file handler that sets extra headers depending on the file suffix."""

    def set_extra_headers(self, path):
        """Add suffix-dependent headers to the response for ``path``.

        ``.html`` files get a wildcard ``Access-Control-Allow-Origin`` header;
        ``.svg`` files get an explicit ``image/svg+xml`` content type.
        """
        if path.endswith(".html"):
            self.set_header("Access-Control-Allow-Origin", "*")
        elif path.endswith(".svg"):
            self.set_header("Content-Type", "image/svg+xml")
class WsRouter:
    """Connects the websocket endpoints to the ZMQ trajectory and prediction sockets.

    Messages arriving on ``/ws/trajectory`` are published on a ZMQ PUB socket;
    messages received on the ZMQ SUB socket are broadcast to the clients of
    ``/ws/prediction``.
    """

    def __init__(self, config: Namespace):
        """Create the ZMQ sockets and the Tornado application.

        Args:
            config: Namespace read for ``zmq_trajectory_addr``,
                ``zmq_prediction_addr``, ``bypass_prediction`` and ``ws_port``.
        """
        self.config = config

        context = zmq.asyncio.Context()

        # Publisher fed by the trajectory websocket. With ``bypass_prediction``
        # set it binds to the prediction address, which is the address the
        # subscriber below connects to.
        self.trajectory_socket = context.socket(zmq.PUB)
        publish_addr = (
            config.zmq_prediction_addr
            if config.bypass_prediction
            else config.zmq_trajectory_addr
        )
        self.trajectory_socket.bind(publish_addr)

        # Subscriber for predictions; the empty prefix matches every message.
        self.prediction_socket = context.socket(zmq.SUB)
        self.prediction_socket.connect(config.zmq_prediction_addr)
        self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')

        self.application = tornado.web.Application(
            [
                (
                    r"/ws/trajectory",
                    WebSocketTrajectoryHandler,
                    {"zmq_socket": self.trajectory_socket},
                ),
                (
                    r"/ws/prediction",
                    WebSocketPredictionHandler,
                    {"config": config},
                ),
                (r"/", DemoHandler, {"config": config}),
            ],
            template_path='trajpred/web/',
            compiled_template_cache=False,
        )

    def start(self):
        """Listen on ``config.ws_port`` and run a new event loop forever.

        The prediction forwarder is scheduled as a task on that loop. This
        call only returns when the loop is stopped.
        """
        evt_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(evt_loop)

        logger.info(f"Listen on {self.config.ws_port}")
        self.application.listen(self.config.ws_port)

        # Keep a reference on the instance so the task stays referenced for
        # as long as the router exists.
        self._forwarder_task = evt_loop.create_task(self.prediction_forwarder())
        evt_loop.run_forever()

    async def prediction_forwarder(self):
        """Forward every message from the prediction socket to the websocket clients.

        A message that cannot be decoded as text is logged and skipped, so a
        single malformed message does not end the forwarding loop.
        """
        logger.info("Starting prediction forwarder")
        while True:
            try:
                msg = await self.prediction_socket.recv_string()
            except UnicodeDecodeError:
                logger.exception("Dropping prediction message that could not be decoded")
                continue
            logger.debug(f"Forward prediction message of {len(msg)} chars")
            WebSocketPredictionHandler.write_to_clients(msg)
def run_ws_forwarder(config: Namespace):
    """Build a ``WsRouter`` from ``config`` and start it."""
    WsRouter(config).start()