Gracefully stop all loops on crash of subprocess

This commit is contained in:
Ruben van de Ven 2023-10-16 11:45:37 +02:00
parent 9b39d7cd9b
commit 3d34263a71
5 changed files with 77 additions and 30 deletions

View file

@ -1,7 +1,9 @@
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass, field from dataclasses import dataclass, field
import logging import logging
from multiprocessing import Event
import pickle import pickle
import sys
import time import time
from typing import Optional from typing import Optional
import numpy as np import numpy as np
@ -21,8 +23,9 @@ class FrameEmitter:
Emit frames in a separate thread so they can be throttled, Emit frames in a separate thread so they can be throttled,
or thrown away when the rest of the system cannot keep up or thrown away when the rest of the system cannot keep up
''' '''
def __init__(self, config: Namespace) -> None: def __init__(self, config: Namespace, is_running: Event) -> None:
self.config = config self.config = config
self.is_running = is_running
context = zmq.Context() context = zmq.Context()
self.frame_sock = context.socket(zmq.PUB) self.frame_sock = context.socket(zmq.PUB)
@ -36,7 +39,7 @@ class FrameEmitter:
frame_duration = 1./fps frame_duration = 1./fps
prev_time = time.time() prev_time = time.time()
while True: while self.is_running.is_set():
ret, img = video.read() ret, img = video.read()
# seek to 0 if video has finished. Infinite loop # seek to 0 if video has finished. Infinite loop
@ -58,8 +61,8 @@ class FrameEmitter:
else: else:
prev_time = new_frame_time prev_time = new_frame_time
def run_frame_emitter(config: Namespace): def run_frame_emitter(config: Namespace, is_running: Event):
router = FrameEmitter(config) router = FrameEmitter(config, is_running)
router.emit_video() router.emit_video()

View file

@ -1,19 +1,32 @@
import logging import logging
from logging.handlers import SocketHandler from logging.handlers import SocketHandler
from multiprocessing import Process, Queue from multiprocessing import Event, Process, Queue
import sys
from trap.config import parser from trap.config import parser
from trap.frame_emitter import run_frame_emitter from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import InferenceServer, run_inference_server from trap.prediction_server import run_prediction_server
from trap.socket_forwarder import run_ws_forwarder from trap.socket_forwarder import run_ws_forwarder
from trap.tracker import run_tracker from trap.tracker import run_tracker
logger = logging.getLogger("trap.plumbing") logger = logging.getLogger("trap.plumbing")
class ExceptionHandlingProcess(Process):
    """Process subclass that signals all sibling loops to stop when its target crashes.

    Requires a shared ``is_running`` Event in ``kwargs``; on any exception raised
    by the target, the exception is logged and the Event is cleared so every loop
    watching it can shut down gracefully.
    """

    def run(self):
        # The crash handler below depends on the shared flag being present.
        assert 'is_running' in self._kwargs
        try:
            # Fix: the original used super(Process, self).run(), which skips
            # Process in the MRO and only worked because Process does not
            # override run(). The zero-argument form resolves correctly.
            super().run()
        except Exception as e:
            logger.exception(e)
            # Deliberately swallow after logging: tell every other loop
            # watching this Event to stop instead of crashing silently alone.
            self._kwargs['is_running'].clear()
def start(): def start():
args = parser.parse_args() args = parser.parse_args()
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
# print(args)
# exit()
logging.basicConfig( logging.basicConfig(
level=loglevel, level=loglevel,
) )
@ -21,6 +34,9 @@ def start():
# set per handler, so we can set it lower for the root logger if remote logging is enabled # set per handler, so we can set it lower for the root logger if remote logging is enabled
root_logger = logging.getLogger() root_logger = logging.getLogger()
[h.setLevel(loglevel) for h in root_logger.handlers] [h.setLevel(loglevel) for h in root_logger.handlers]
isRunning = Event()
isRunning.set()
if args.remote_log_addr: if args.remote_log_addr:
@ -29,15 +45,17 @@ def start():
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port) socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
root_logger.addHandler(socket_handler) root_logger.addHandler(socket_handler)
# instantiating process with arguments # instantiating process with arguments
procs = [ procs = [
Process(target=run_ws_forwarder, args=(args,), name='forwarder'), ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
Process(target=run_frame_emitter, args=(args,), name='frame_emitter'), ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
Process(target=run_tracker, args=(args,), name='tracker'), ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
] ]
if not args.bypass_prediction: if not args.bypass_prediction:
procs.append( procs.append(
Process(target=run_inference_server, args=(args,), name='inference'), ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
) )
logger.info("start") logger.info("start")

View file

@ -1,6 +1,7 @@
# adapted from Trajectron++ online_server.py # adapted from Trajectron++ online_server.py
from argparse import Namespace
import logging import logging
from multiprocessing import Queue from multiprocessing import Event, Queue
import os import os
import pickle import pickle
import time import time
@ -108,9 +109,13 @@ def get_maps_for_input(input_dict, scene, hyperparams):
return maps_dict return maps_dict
class InferenceServer: class PredictionServer:
def __init__(self, config: dict): def __init__(self, config: Namespace, is_running: Event):
self.config = config self.config = config
self.is_running = is_running
if self.config.eval_device == 'cpu':
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
context = zmq.Context() context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB) self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
@ -199,7 +204,7 @@ class InferenceServer:
trajectron.set_environment(online_env, init_timestep) trajectron.set_environment(online_env, init_timestep)
timestep = init_timestep + 1 timestep = init_timestep + 1
while True: while self.is_running.is_set():
timestep += 1 timestep += 1
# for timestep in range(init_timestep + 1, eval_scene.timesteps): # for timestep in range(init_timestep + 1, eval_scene.timesteps):
@ -350,9 +355,19 @@ class InferenceServer:
data = json.dumps(response) data = json.dumps(response)
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)") logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
self.prediction_socket.send_string(data) self.prediction_socket.send_string(data)
logger.info('Stopping')
def run_inference_server(config): def run_prediction_server(config: Namespace, is_running: Event):
s = InferenceServer(config)
# attempt to trace the warnings coming from pytorch
# def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
# log = file if hasattr(file,'write') else sys.stderr
# traceback.print_stack(file=log)
# log.write(warnings.formatwarning(message, category, filename, lineno, line))
# warnings.showwarning = warn_with_traceback
s = PredictionServer(config, is_running)
s.run() s.run()

View file

@ -2,6 +2,7 @@
from argparse import Namespace from argparse import Namespace
import asyncio import asyncio
import logging import logging
from multiprocessing import Event
from typing import Set, Union, Dict, Any from typing import Set, Union, Dict, Any
from typing_extensions import Self from typing_extensions import Self
@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
class WsRouter: class WsRouter:
def __init__(self, config: Namespace): def __init__(self, config: Namespace, is_running: Event):
self.config = config self.config = config
self.is_running = is_running
context = zmq.asyncio.Context() context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB) self.trajectory_socket = context.socket(zmq.PUB)
@ -138,25 +140,31 @@ class WsRouter:
def start(self): def start(self):
evt_loop = asyncio.new_event_loop() self.evt_loop = asyncio.new_event_loop()
asyncio.set_event_loop(evt_loop) asyncio.set_event_loop(self.evt_loop)
# loop = tornado.ioloop.IOLoop.current() # loop = tornado.ioloop.IOLoop.current()
logger.info(f"Listen on {self.config.ws_port}") logger.info(f"Listen on {self.config.ws_port}")
self.application.listen(self.config.ws_port) self.application.listen(self.config.ws_port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
task = evt_loop.create_task(self.prediction_forwarder()) task = self.evt_loop.create_task(self.prediction_forwarder())
evt_loop.run_forever() self.evt_loop.run_forever()
async def prediction_forwarder(self): async def prediction_forwarder(self):
logger.info("Starting prediction forwarder") logger.info("Starting prediction forwarder")
while True: while self.is_running.is_set():
msg = await self.prediction_socket.recv_string() msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars") logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg) WebSocketPredictionHandler.write_to_clients(msg)
# die together:
self.evt_loop.stop()
logger.info('Stopping')
def run_ws_forwarder(config: Namespace): def run_ws_forwarder(config: Namespace, is_running: Event):
router = WsRouter(config) router = WsRouter(config, is_running)
router.start() router.start()

View file

@ -1,6 +1,7 @@
from argparse import Namespace from argparse import Namespace
import json import json
import logging import logging
from multiprocessing import Event
import pickle import pickle
import time import time
import numpy as np import numpy as np
@ -20,8 +21,9 @@ Detections = [Detection]
logger = logging.getLogger("trap.tracker") logger = logging.getLogger("trap.tracker")
class Tracker: class Tracker:
def __init__(self, config: Namespace): def __init__(self, config: Namespace, is_running: Event):
self.config = config self.config = config
self.is_running = is_running
context = zmq.Context() context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB) self.frame_sock = context.socket(zmq.SUB)
@ -54,7 +56,7 @@ class Tracker:
def track(self): def track(self):
while True: while self.is_running.is_set():
msg = self.frame_sock.recv() msg = self.frame_sock.recv()
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s") # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
@ -89,6 +91,7 @@ class Tracker:
#TODO calculate fps (also for other loops to see asynchrony) #TODO calculate fps (also for other loops to see asynchrony)
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
logger.info('Stopping')
def detect_persons(self, frame) -> Detections: def detect_persons(self, frame) -> Detections:
@ -128,6 +131,6 @@ class Tracker:
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections] return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
def run_tracker(config: Namespace): def run_tracker(config: Namespace, is_running: Event):
router = Tracker(config) router = Tracker(config, is_running)
router.track() router.track()