diff --git a/trap/frame_emitter.py b/trap/frame_emitter.py
index bbcee05..12579d8 100644
--- a/trap/frame_emitter.py
+++ b/trap/frame_emitter.py
@@ -1,7 +1,9 @@
 from argparse import Namespace
 from dataclasses import dataclass, field
 import logging
+from multiprocessing import Event
 import pickle
+import sys
 import time
 from typing import Optional
 import numpy as np
@@ -21,8 +23,9 @@ class FrameEmitter:
     Emit frame in a separate threat so they can be throttled,
     or thrown away when the rest of the system cannot keep up
     '''
-    def __init__(self, config: Namespace) -> None:
+    def __init__(self, config: Namespace, is_running: Event) -> None:
         self.config = config
+        self.is_running = is_running
 
         context = zmq.Context()
         self.frame_sock = context.socket(zmq.PUB)
@@ -36,7 +39,7 @@ class FrameEmitter:
         frame_duration = 1./fps
         prev_time = time.time()
 
-        while True:
+        while self.is_running.is_set():
            ret, img = video.read()
 
            # seek to 0 if video has finished. Infinite loop
@@ -58,8 +61,8 @@ class FrameEmitter:
            else:
                prev_time = new_frame_time
 
-
+
 
-def run_frame_emitter(config: Namespace):
-    router = FrameEmitter(config)
+def run_frame_emitter(config: Namespace, is_running: Event):
+    router = FrameEmitter(config, is_running)
     router.emit_video()
\ No newline at end of file
diff --git a/trap/plumber.py b/trap/plumber.py
index c0253b6..a33f487 100644
--- a/trap/plumber.py
+++ b/trap/plumber.py
@@ -1,19 +1,32 @@
 import logging
 from logging.handlers import SocketHandler
-from multiprocessing import Process, Queue
+from multiprocessing import Event, Process, Queue
+import sys
 
 from trap.config import parser
 from trap.frame_emitter import run_frame_emitter
-from trap.prediction_server import InferenceServer, run_inference_server
+from trap.prediction_server import run_prediction_server
 from trap.socket_forwarder import run_ws_forwarder
 from trap.tracker import run_tracker
 
 logger = logging.getLogger("trap.plumbing")
 
+class ExceptionHandlingProcess(Process):
+
+    def run(self):
+        assert 'is_running' in self._kwargs
+        try:
+            super(Process, self).run()
+        except Exception as e:
+            logger.exception(e)
+            self._kwargs['is_running'].clear()
+
+
 def start():
     args = parser.parse_args()
     loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
-    
+    # print(args)
+    # exit()
     logging.basicConfig(
         level=loglevel,
     )
@@ -21,6 +34,9 @@ def start():
     # set per handler, so we can set it lower for the root logger if remote logging is enabled
     root_logger = logging.getLogger()
     [h.setLevel(loglevel) for h in root_logger.handlers]
+
+    isRunning = Event()
+    isRunning.set()
 
 
     if args.remote_log_addr:
@@ -29,15 +45,17 @@ def start():
         socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
         root_logger.addHandler(socket_handler)
 
+
+    # instantiating process with arguments
     procs = [
-        Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
-        Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
-        Process(target=run_tracker, args=(args,), name='tracker'),
+        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
+        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
+        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
     ]
 
     if not args.bypass_prediction:
         procs.append(
-            Process(target=run_inference_server, args=(args,), name='inference'),
+            ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
         )
 
     logger.info("start")
diff --git a/trap/prediction_server.py b/trap/prediction_server.py
index 4144739..772c1ad 100644
--- a/trap/prediction_server.py
+++ b/trap/prediction_server.py
@@ -1,6 +1,7 @@
 # adapted from Trajectron++ online_server.py
+from argparse import Namespace
 import logging
-from multiprocessing import Queue
+from multiprocessing import Event, Queue
 import os
 import pickle
 import time
@@ -108,9 +109,13 @@ def get_maps_for_input(input_dict, scene, hyperparams):
     return maps_dict
 
 
-class InferenceServer:
-    def __init__(self, config: dict):
+class PredictionServer:
+    def __init__(self, config: Namespace, is_running: Event):
         self.config = config
+        self.is_running = is_running
+
+        if self.config.eval_device == 'cpu':
+            logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
 
         context = zmq.Context()
         self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
@@ -199,7 +204,7 @@ class InferenceServer:
         trajectron.set_environment(online_env, init_timestep)
 
         timestep = init_timestep + 1
-        while True:
+        while self.is_running.is_set():
            timestep += 1
 
            # for timestep in range(init_timestep + 1, eval_scene.timesteps):
@@ -350,9 +355,19 @@ class InferenceServer:
            data = json.dumps(response)
            logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
            self.prediction_socket.send_string(data)
+        logger.info('Stopping')
 
 
 
-def run_inference_server(config):
-    s = InferenceServer(config)
+def run_prediction_server(config: Namespace, is_running: Event):
+
+    # attempt to trace the warnings coming from pytorch
+    # def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
+
+    #     log = file if hasattr(file,'write') else sys.stderr
+    #     traceback.print_stack(file=log)
+    #     log.write(warnings.formatwarning(message, category, filename, lineno, line))
+
+    # warnings.showwarning = warn_with_traceback
+    s = PredictionServer(config, is_running)
     s.run()
\ No newline at end of file
diff --git a/trap/socket_forwarder.py b/trap/socket_forwarder.py
index aecee49..654128d 100644
--- a/trap/socket_forwarder.py
+++ b/trap/socket_forwarder.py
@@ -2,6 +2,7 @@ from argparse import Namespace
 
 import asyncio
 import logging
+from multiprocessing import Event
 from typing import Set, Union, Dict, Any
 from typing_extensions import Self
 
@@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
 
 
 class WsRouter:
-    def __init__(self, config: Namespace):
+    def __init__(self, config: Namespace, is_running: Event):
         self.config = config
+        self.is_running = is_running
 
         context = zmq.asyncio.Context()
         self.trajectory_socket = context.socket(zmq.PUB)
@@ -138,25 +140,31 @@ class WsRouter:
 
     def start(self):
-        evt_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(evt_loop)
+        self.evt_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.evt_loop)
 
         # loop = tornado.ioloop.IOLoop.current()
         logger.info(f"Listen on {self.config.ws_port}")
         self.application.listen(self.config.ws_port)
         loop = asyncio.get_event_loop()
 
-        task = evt_loop.create_task(self.prediction_forwarder())
+        task = self.evt_loop.create_task(self.prediction_forwarder())
 
-        evt_loop.run_forever()
+        self.evt_loop.run_forever()
+
 
     async def prediction_forwarder(self):
         logger.info("Starting prediction forwarder")
-        while True:
+        while self.is_running.is_set():
            msg = await self.prediction_socket.recv_string()
            logger.debug(f"Forward prediction message of {len(msg)} chars")
            WebSocketPredictionHandler.write_to_clients(msg)
+
+        # die together:
+        self.evt_loop.stop()
+        logger.info('Stopping')
+
 
 
-def run_ws_forwarder(config: Namespace):
-    router = WsRouter(config)
+def run_ws_forwarder(config: Namespace, is_running: Event):
+    router = WsRouter(config, is_running)
     router.start()
\ No newline at end of file
diff --git a/trap/tracker.py b/trap/tracker.py
index dca8336..658e245 100644
--- a/trap/tracker.py
+++ b/trap/tracker.py
@@ -1,6 +1,7 @@
 from argparse import Namespace
 import json
 import logging
+from multiprocessing import Event
 import pickle
 import time
 import numpy as np
@@ -20,8 +21,9 @@ Detections = [Detection]
 logger = logging.getLogger("trap.tracker")
 
 class Tracker:
-    def __init__(self, config: Namespace):
+    def __init__(self, config: Namespace, is_running: Event):
         self.config = config
+        self.is_running = is_running
 
         context = zmq.Context()
         self.frame_sock = context.socket(zmq.SUB)
@@ -54,7 +56,7 @@ class Tracker:
 
 
     def track(self):
-        while True:
+        while self.is_running.is_set():
            msg = self.frame_sock.recv()
            frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
@@ -89,6 +91,7 @@ class Tracker:
 
            #TODO calculate fps (also for other loops to see asynchonity)
            # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
+        logger.info('Stopping')
 
 
    def detect_persons(self, frame) -> Detections:
@@ -128,6 +131,6 @@ class Tracker:
        return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
 
 
-def run_tracker(config: Namespace):
-    router = Tracker(config)
+def run_tracker(config: Namespace, is_running: Event):
+    router = Tracker(config, is_running)
    router.track()
\ No newline at end of file
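
Below is a minimal, self-contained sketch (not part of the patch) of the shutdown pattern the diff introduces: one shared multiprocessing.Event, a Process subclass that clears it when a worker raises, and worker loops that poll it. ExceptionHandlingProcess and the kwargs-based is_running handoff mirror trap/plumber.py; the worker function, its names and the simulated crash are illustrative only.

import time
from multiprocessing import Event, Process


class ExceptionHandlingProcess(Process):
    # as in trap/plumber.py: if the target raises, clear the shared Event
    # so every sibling loop that checks is_running.is_set() winds down too
    def run(self):
        assert 'is_running' in self._kwargs
        try:
            super().run()
        except Exception as e:
            print(f"{self.name} failed: {e}")
            self._kwargs['is_running'].clear()


def worker(name, is_running, fail_after=None):
    # stand-in for the tracker / frame emitter / prediction loops
    started = time.time()
    while is_running.is_set():
        if fail_after is not None and time.time() - started > fail_after:
            raise RuntimeError("simulated crash")
        time.sleep(0.1)
    print(f"{name}: stopping")


if __name__ == "__main__":
    is_running = Event()
    is_running.set()
    procs = [
        ExceptionHandlingProcess(target=worker, kwargs={'name': 'tracker', 'is_running': is_running}, name='tracker'),
        ExceptionHandlingProcess(target=worker, kwargs={'name': 'emitter', 'is_running': is_running, 'fail_after': 0.5}, name='emitter'),
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

Once the simulated "emitter" crashes, its wrapper clears the Event and the "tracker" loop exits on its next poll, which is the same cooperative shutdown the patch wires through the forwarder, frame emitter, tracker and prediction server.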