Gracefully stop all loops on crash of subprocess
This commit is contained in:
parent
9b39d7cd9b
commit
3d34263a71
5 changed files with 77 additions and 30 deletions
|
@ -1,7 +1,9 @@
|
|||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
|
@ -21,8 +23,9 @@ class FrameEmitter:
|
|||
Emit frame in a separate threat so they can be throttled,
|
||||
or thrown away when the rest of the system cannot keep up
|
||||
'''
|
||||
def __init__(self, config: Namespace) -> None:
|
||||
def __init__(self, config: Namespace, is_running: Event) -> None:
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
context = zmq.Context()
|
||||
self.frame_sock = context.socket(zmq.PUB)
|
||||
|
@ -36,7 +39,7 @@ class FrameEmitter:
|
|||
frame_duration = 1./fps
|
||||
|
||||
prev_time = time.time()
|
||||
while True:
|
||||
while self.is_running.is_set():
|
||||
ret, img = video.read()
|
||||
|
||||
# seek to 0 if video has finished. Infinite loop
|
||||
|
@ -58,8 +61,8 @@ class FrameEmitter:
|
|||
else:
|
||||
prev_time = new_frame_time
|
||||
|
||||
|
||||
|
||||
|
||||
def run_frame_emitter(config: Namespace):
|
||||
router = FrameEmitter(config)
|
||||
def run_frame_emitter(config: Namespace, is_running: Event):
|
||||
router = FrameEmitter(config, is_running)
|
||||
router.emit_video()
|
|
@ -1,19 +1,32 @@
|
|||
import logging
|
||||
from logging.handlers import SocketHandler
|
||||
from multiprocessing import Process, Queue
|
||||
from multiprocessing import Event, Process, Queue
|
||||
import sys
|
||||
from trap.config import parser
|
||||
from trap.frame_emitter import run_frame_emitter
|
||||
from trap.prediction_server import InferenceServer, run_inference_server
|
||||
from trap.prediction_server import run_prediction_server
|
||||
from trap.socket_forwarder import run_ws_forwarder
|
||||
from trap.tracker import run_tracker
|
||||
|
||||
|
||||
logger = logging.getLogger("trap.plumbing")
|
||||
|
||||
class ExceptionHandlingProcess(Process):
|
||||
|
||||
def run(self):
|
||||
assert 'is_running' in self._kwargs
|
||||
try:
|
||||
super(Process, self).run()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._kwargs['is_running'].clear()
|
||||
|
||||
|
||||
def start():
|
||||
args = parser.parse_args()
|
||||
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
|
||||
|
||||
# print(args)
|
||||
# exit()
|
||||
logging.basicConfig(
|
||||
level=loglevel,
|
||||
)
|
||||
|
@ -21,6 +34,9 @@ def start():
|
|||
# set per handler, so we can set it lower for the root logger if remote logging is enabled
|
||||
root_logger = logging.getLogger()
|
||||
[h.setLevel(loglevel) for h in root_logger.handlers]
|
||||
|
||||
isRunning = Event()
|
||||
isRunning.set()
|
||||
|
||||
|
||||
if args.remote_log_addr:
|
||||
|
@ -29,15 +45,17 @@ def start():
|
|||
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
|
||||
root_logger.addHandler(socket_handler)
|
||||
|
||||
|
||||
|
||||
# instantiating process with arguments
|
||||
procs = [
|
||||
Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
|
||||
Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
|
||||
Process(target=run_tracker, args=(args,), name='tracker'),
|
||||
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
|
||||
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
|
||||
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
|
||||
]
|
||||
if not args.bypass_prediction:
|
||||
procs.append(
|
||||
Process(target=run_inference_server, args=(args,), name='inference'),
|
||||
ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
|
||||
)
|
||||
|
||||
logger.info("start")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# adapted from Trajectron++ online_server.py
|
||||
from argparse import Namespace
|
||||
import logging
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing import Event, Queue
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
@ -108,9 +109,13 @@ def get_maps_for_input(input_dict, scene, hyperparams):
|
|||
return maps_dict
|
||||
|
||||
|
||||
class InferenceServer:
|
||||
def __init__(self, config: dict):
|
||||
class PredictionServer:
|
||||
def __init__(self, config: Namespace, is_running: Event):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
if self.config.eval_device == 'cpu':
|
||||
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
|
||||
|
||||
context = zmq.Context()
|
||||
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
||||
|
@ -199,7 +204,7 @@ class InferenceServer:
|
|||
trajectron.set_environment(online_env, init_timestep)
|
||||
|
||||
timestep = init_timestep + 1
|
||||
while True:
|
||||
while self.is_running.is_set():
|
||||
timestep += 1
|
||||
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
||||
|
||||
|
@ -350,9 +355,19 @@ class InferenceServer:
|
|||
data = json.dumps(response)
|
||||
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
||||
self.prediction_socket.send_string(data)
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
|
||||
def run_inference_server(config):
|
||||
s = InferenceServer(config)
|
||||
def run_prediction_server(config: Namespace, is_running: Event):
|
||||
|
||||
# attempt to trace the warnings coming from pytorch
|
||||
# def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||
|
||||
# log = file if hasattr(file,'write') else sys.stderr
|
||||
# traceback.print_stack(file=log)
|
||||
# log.write(warnings.formatwarning(message, category, filename, lineno, line))
|
||||
|
||||
# warnings.showwarning = warn_with_traceback
|
||||
s = PredictionServer(config, is_running)
|
||||
s.run()
|
|
@ -2,6 +2,7 @@
|
|||
from argparse import Namespace
|
||||
import asyncio
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
from typing import Set, Union, Dict, Any
|
||||
from typing_extensions import Self
|
||||
|
||||
|
@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
|
|||
|
||||
|
||||
class WsRouter:
|
||||
def __init__(self, config: Namespace):
|
||||
def __init__(self, config: Namespace, is_running: Event):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
context = zmq.asyncio.Context()
|
||||
self.trajectory_socket = context.socket(zmq.PUB)
|
||||
|
@ -138,25 +140,31 @@ class WsRouter:
|
|||
|
||||
def start(self):
|
||||
|
||||
evt_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(evt_loop)
|
||||
self.evt_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.evt_loop)
|
||||
|
||||
# loop = tornado.ioloop.IOLoop.current()
|
||||
logger.info(f"Listen on {self.config.ws_port}")
|
||||
self.application.listen(self.config.ws_port)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
task = evt_loop.create_task(self.prediction_forwarder())
|
||||
task = self.evt_loop.create_task(self.prediction_forwarder())
|
||||
|
||||
evt_loop.run_forever()
|
||||
self.evt_loop.run_forever()
|
||||
|
||||
|
||||
async def prediction_forwarder(self):
|
||||
logger.info("Starting prediction forwarder")
|
||||
while True:
|
||||
while self.is_running.is_set():
|
||||
msg = await self.prediction_socket.recv_string()
|
||||
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
||||
WebSocketPredictionHandler.write_to_clients(msg)
|
||||
|
||||
# die together:
|
||||
self.evt_loop.stop()
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
def run_ws_forwarder(config: Namespace):
|
||||
router = WsRouter(config)
|
||||
def run_ws_forwarder(config: Namespace, is_running: Event):
|
||||
router = WsRouter(config, is_running)
|
||||
router.start()
|
|
@ -1,6 +1,7 @@
|
|||
from argparse import Namespace
|
||||
import json
|
||||
import logging
|
||||
from multiprocessing import Event
|
||||
import pickle
|
||||
import time
|
||||
import numpy as np
|
||||
|
@ -20,8 +21,9 @@ Detections = [Detection]
|
|||
logger = logging.getLogger("trap.tracker")
|
||||
|
||||
class Tracker:
|
||||
def __init__(self, config: Namespace):
|
||||
def __init__(self, config: Namespace, is_running: Event):
|
||||
self.config = config
|
||||
self.is_running = is_running
|
||||
|
||||
context = zmq.Context()
|
||||
self.frame_sock = context.socket(zmq.SUB)
|
||||
|
@ -54,7 +56,7 @@ class Tracker:
|
|||
|
||||
|
||||
def track(self):
|
||||
while True:
|
||||
while self.is_running.is_set():
|
||||
msg = self.frame_sock.recv()
|
||||
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
|
||||
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||
|
@ -89,6 +91,7 @@ class Tracker:
|
|||
|
||||
#TODO calculate fps (also for other loops to see asynchonity)
|
||||
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
|
||||
logger.info('Stopping')
|
||||
|
||||
|
||||
def detect_persons(self, frame) -> Detections:
|
||||
|
@ -128,6 +131,6 @@ class Tracker:
|
|||
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
|
||||
|
||||
|
||||
def run_tracker(config: Namespace):
|
||||
router = Tracker(config)
|
||||
def run_tracker(config: Namespace, is_running: Event):
|
||||
router = Tracker(config, is_running)
|
||||
router.track()
|
Loading…
Reference in a new issue