fix socket loop

This commit is contained in:
Ruben van de Ven 2024-01-23 15:55:09 +01:00
parent f6648b9c18
commit 7e3ba9acd2

View file

@ -1,7 +1,9 @@
from argparse import Namespace from argparse import Namespace
import asyncio import asyncio
import dataclasses
import errno import errno
import json
import logging import logging
from multiprocessing import Event from multiprocessing import Event
import subprocess import subprocess
@ -15,6 +17,8 @@ import tornado.websocket
import zmq import zmq
import zmq.asyncio import zmq.asyncio
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.forwarder") logger = logging.getLogger("trap.forwarder")
@ -112,11 +116,12 @@ class WsRouter:
context = zmq.asyncio.Context() context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB) self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr) self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB) self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.connect(config.zmq_prediction_addr) self.prediction_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'') self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_socket.connect(config.zmq_prediction_addr if not self.config.bypass_prediction else config.zmq_trajectory_addr)
self.application = tornado.web.Application( self.application = tornado.web.Application(
[ [
@ -166,11 +171,16 @@ class WsRouter:
logger.info("Starting prediction forwarder") logger.info("Starting prediction forwarder")
while self.is_running.is_set(): while self.is_running.is_set():
# timeout so that if no events occur, loop can still stop on is_running # timeout so that if no events occur, loop can still stop on is_running
has_event = await self.prediction_socket.poll(timeout=1) has_event = await self.prediction_socket.poll(timeout=1000)
if has_event: if has_event:
msg = await self.prediction_socket.recv_string() try:
frame: Frame = await self.prediction_socket.recv_pyobj()
# tacks = [dataclasses.asdict(h) for t in frame.tracks.values() for t.history in t]
msg = json.dumps(frame.aslist())
logger.debug(f"Forward prediction message of {len(msg)} chars") logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg) WebSocketPredictionHandler.write_to_clients(msg)
except Exception as e:
logger.exception(e)
# die together: # die together:
self.evt_loop.stop() self.evt_loop.stop()