Gracefully stop all loops on crash of subprocess
This commit is contained in:
		
							parent
							
								
									9b39d7cd9b
								
							
						
					
					
						commit
						3d34263a71
					
				
					 5 changed files with 77 additions and 30 deletions
				
			
		| 
						 | 
					@ -1,7 +1,9 @@
 | 
				
			||||||
from argparse import Namespace
 | 
					from argparse import Namespace
 | 
				
			||||||
from dataclasses import dataclass, field
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					from multiprocessing import Event
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
| 
						 | 
					@ -21,8 +23,9 @@ class FrameEmitter:
 | 
				
			||||||
    Emit frames in a separate thread so they can be throttled,
 | 
					    Emit frames in a separate thread so they can be throttled,
 | 
				
			||||||
    or thrown away when the rest of the system cannot keep up
 | 
					    or thrown away when the rest of the system cannot keep up
 | 
				
			||||||
    '''
 | 
					    '''
 | 
				
			||||||
    def __init__(self, config: Namespace) -> None:
 | 
					    def __init__(self, config: Namespace, is_running: Event) -> None:
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.is_running = is_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.frame_sock = context.socket(zmq.PUB)
 | 
					        self.frame_sock = context.socket(zmq.PUB)
 | 
				
			||||||
| 
						 | 
					@ -36,7 +39,7 @@ class FrameEmitter:
 | 
				
			||||||
        frame_duration = 1./fps
 | 
					        frame_duration = 1./fps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        prev_time = time.time()
 | 
					        prev_time = time.time()
 | 
				
			||||||
        while True:
 | 
					        while self.is_running.is_set():
 | 
				
			||||||
            ret, img = video.read()
 | 
					            ret, img = video.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # seek to 0 if video has finished. Infinite loop
 | 
					            # seek to 0 if video has finished. Infinite loop
 | 
				
			||||||
| 
						 | 
					@ -60,6 +63,6 @@ class FrameEmitter:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      
 | 
					      
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_frame_emitter(config: Namespace):
 | 
					def run_frame_emitter(config: Namespace, is_running: Event):
 | 
				
			||||||
    router = FrameEmitter(config)
 | 
					    router = FrameEmitter(config, is_running)
 | 
				
			||||||
    router.emit_video()
 | 
					    router.emit_video()
 | 
				
			||||||
| 
						 | 
					@ -1,19 +1,32 @@
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from logging.handlers import SocketHandler
 | 
					from logging.handlers import SocketHandler
 | 
				
			||||||
from multiprocessing import Process, Queue
 | 
					from multiprocessing import Event, Process, Queue
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
from trap.config import parser
 | 
					from trap.config import parser
 | 
				
			||||||
from trap.frame_emitter import run_frame_emitter
 | 
					from trap.frame_emitter import run_frame_emitter
 | 
				
			||||||
from trap.prediction_server import InferenceServer, run_inference_server
 | 
					from trap.prediction_server import run_prediction_server
 | 
				
			||||||
from trap.socket_forwarder import run_ws_forwarder
 | 
					from trap.socket_forwarder import run_ws_forwarder
 | 
				
			||||||
from trap.tracker import run_tracker
 | 
					from trap.tracker import run_tracker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger("trap.plumbing")
 | 
					logger = logging.getLogger("trap.plumbing")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ExceptionHandlingProcess(Process):
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def run(self):
 | 
				
			||||||
 | 
					        assert 'is_running' in self._kwargs
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            super(Process, self).run()
 | 
				
			||||||
 | 
					        except Exception as e:
 | 
				
			||||||
 | 
					            logger.exception(e)
 | 
				
			||||||
 | 
					            self._kwargs['is_running'].clear()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def start():
 | 
					def start():
 | 
				
			||||||
    args = parser.parse_args()
 | 
					    args = parser.parse_args()
 | 
				
			||||||
    loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0  else logging.INFO
 | 
					    loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0  else logging.INFO
 | 
				
			||||||
    
 | 
					    # print(args)
 | 
				
			||||||
 | 
					    # exit()
 | 
				
			||||||
    logging.basicConfig(
 | 
					    logging.basicConfig(
 | 
				
			||||||
        level=loglevel,
 | 
					        level=loglevel,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
| 
						 | 
					@ -22,6 +35,9 @@ def start():
 | 
				
			||||||
    root_logger = logging.getLogger()
 | 
					    root_logger = logging.getLogger()
 | 
				
			||||||
    [h.setLevel(loglevel) for h in root_logger.handlers]
 | 
					    [h.setLevel(loglevel) for h in root_logger.handlers]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    isRunning = Event()
 | 
				
			||||||
 | 
					    isRunning.set()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if args.remote_log_addr:
 | 
					    if args.remote_log_addr:
 | 
				
			||||||
        logging.captureWarnings(True)
 | 
					        logging.captureWarnings(True)
 | 
				
			||||||
| 
						 | 
					@ -29,15 +45,17 @@ def start():
 | 
				
			||||||
        socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
 | 
					        socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
 | 
				
			||||||
        root_logger.addHandler(socket_handler)
 | 
					        root_logger.addHandler(socket_handler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # instantiating process with arguments
 | 
					    # instantiating process with arguments
 | 
				
			||||||
    procs = [
 | 
					    procs = [
 | 
				
			||||||
        Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
 | 
					        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
 | 
				
			||||||
        Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
 | 
					        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
 | 
				
			||||||
        Process(target=run_tracker, args=(args,), name='tracker'),
 | 
					        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
    if not args.bypass_prediction:
 | 
					    if not args.bypass_prediction:
 | 
				
			||||||
        procs.append(
 | 
					        procs.append(
 | 
				
			||||||
            Process(target=run_inference_server, args=(args,), name='inference'),
 | 
					            ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    logger.info("start")
 | 
					    logger.info("start")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
# adapted from Trajectron++ online_server.py
 | 
					# adapted from Trajectron++ online_server.py
 | 
				
			||||||
 | 
					from argparse import Namespace
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from multiprocessing import Queue
 | 
					from multiprocessing import Event, Queue
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
| 
						 | 
					@ -108,9 +109,13 @@ def get_maps_for_input(input_dict, scene, hyperparams):
 | 
				
			||||||
    return maps_dict
 | 
					    return maps_dict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class InferenceServer:
 | 
					class PredictionServer:
 | 
				
			||||||
    def __init__(self, config: dict):
 | 
					    def __init__(self, config: Namespace, is_running: Event):
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.is_running = is_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.config.eval_device == 'cpu':
 | 
				
			||||||
 | 
					            logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
 | 
					        self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
 | 
				
			||||||
| 
						 | 
					@ -199,7 +204,7 @@ class InferenceServer:
 | 
				
			||||||
        trajectron.set_environment(online_env, init_timestep)
 | 
					        trajectron.set_environment(online_env, init_timestep)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        timestep = init_timestep + 1
 | 
					        timestep = init_timestep + 1
 | 
				
			||||||
        while True:
 | 
					        while self.is_running.is_set():
 | 
				
			||||||
            timestep += 1
 | 
					            timestep += 1
 | 
				
			||||||
            # for timestep in range(init_timestep + 1, eval_scene.timesteps):
 | 
					            # for timestep in range(init_timestep + 1, eval_scene.timesteps):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -350,9 +355,19 @@ class InferenceServer:
 | 
				
			||||||
            data = json.dumps(response)
 | 
					            data = json.dumps(response)
 | 
				
			||||||
            logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
 | 
					            logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
 | 
				
			||||||
            self.prediction_socket.send_string(data)
 | 
					            self.prediction_socket.send_string(data)
 | 
				
			||||||
 | 
					        logger.info('Stopping')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            
 | 
					            
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_inference_server(config):
 | 
					def run_prediction_server(config: Namespace, is_running: Event):
 | 
				
			||||||
    s = InferenceServer(config)
 | 
					
 | 
				
			||||||
 | 
					    # attempt to trace the warnings coming from pytorch
 | 
				
			||||||
 | 
					    # def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #     log = file if hasattr(file,'write') else sys.stderr
 | 
				
			||||||
 | 
					    #     traceback.print_stack(file=log)
 | 
				
			||||||
 | 
					    #     log.write(warnings.formatwarning(message, category, filename, lineno, line))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # warnings.showwarning = warn_with_traceback
 | 
				
			||||||
 | 
					    s = PredictionServer(config, is_running)
 | 
				
			||||||
    s.run()
 | 
					    s.run()
 | 
				
			||||||
| 
						 | 
					@ -2,6 +2,7 @@
 | 
				
			||||||
from argparse import Namespace
 | 
					from argparse import Namespace
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					from multiprocessing import Event
 | 
				
			||||||
from typing import Set, Union, Dict, Any
 | 
					from typing import Set, Union, Dict, Any
 | 
				
			||||||
from typing_extensions import Self
 | 
					from typing_extensions import Self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WsRouter:
 | 
					class WsRouter:
 | 
				
			||||||
    def __init__(self, config: Namespace):
 | 
					    def __init__(self, config: Namespace, is_running: Event):
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.is_running = is_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context = zmq.asyncio.Context()
 | 
					        context = zmq.asyncio.Context()
 | 
				
			||||||
        self.trajectory_socket = context.socket(zmq.PUB)
 | 
					        self.trajectory_socket = context.socket(zmq.PUB)
 | 
				
			||||||
| 
						 | 
					@ -138,25 +140,31 @@ class WsRouter:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def start(self):
 | 
					    def start(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        evt_loop = asyncio.new_event_loop()
 | 
					        self.evt_loop = asyncio.new_event_loop()
 | 
				
			||||||
        asyncio.set_event_loop(evt_loop)
 | 
					        asyncio.set_event_loop(self.evt_loop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # loop = tornado.ioloop.IOLoop.current()
 | 
					        # loop = tornado.ioloop.IOLoop.current()
 | 
				
			||||||
        logger.info(f"Listen on {self.config.ws_port}")
 | 
					        logger.info(f"Listen on {self.config.ws_port}")
 | 
				
			||||||
        self.application.listen(self.config.ws_port)
 | 
					        self.application.listen(self.config.ws_port)
 | 
				
			||||||
        loop = asyncio.get_event_loop()
 | 
					        loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        task = evt_loop.create_task(self.prediction_forwarder())
 | 
					        task = self.evt_loop.create_task(self.prediction_forwarder())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.evt_loop.run_forever()
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        evt_loop.run_forever()
 | 
					 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    async def prediction_forwarder(self):
 | 
					    async def prediction_forwarder(self):
 | 
				
			||||||
        logger.info("Starting prediction forwarder")
 | 
					        logger.info("Starting prediction forwarder")
 | 
				
			||||||
        while True:
 | 
					        while self.is_running.is_set():
 | 
				
			||||||
            msg = await self.prediction_socket.recv_string()
 | 
					            msg = await self.prediction_socket.recv_string()
 | 
				
			||||||
            logger.debug(f"Forward prediction message of {len(msg)} chars")
 | 
					            logger.debug(f"Forward prediction message of {len(msg)} chars")
 | 
				
			||||||
            WebSocketPredictionHandler.write_to_clients(msg)
 | 
					            WebSocketPredictionHandler.write_to_clients(msg)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
def run_ws_forwarder(config: Namespace):
 | 
					        # die together:
 | 
				
			||||||
    router = WsRouter(config)
 | 
					        self.evt_loop.stop()
 | 
				
			||||||
 | 
					        logger.info('Stopping')
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def run_ws_forwarder(config: Namespace, is_running: Event):
 | 
				
			||||||
 | 
					    router = WsRouter(config, is_running)
 | 
				
			||||||
    router.start()
 | 
					    router.start()
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
from argparse import Namespace
 | 
					from argparse import Namespace
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					from multiprocessing import Event
 | 
				
			||||||
import pickle
 | 
					import pickle
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
| 
						 | 
					@ -20,8 +21,9 @@ Detections = [Detection]
 | 
				
			||||||
logger = logging.getLogger("trap.tracker")
 | 
					logger = logging.getLogger("trap.tracker")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Tracker:
 | 
					class Tracker:
 | 
				
			||||||
    def __init__(self, config: Namespace):
 | 
					    def __init__(self, config: Namespace, is_running: Event):
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
 | 
					        self.is_running = is_running
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context = zmq.Context()
 | 
					        context = zmq.Context()
 | 
				
			||||||
        self.frame_sock = context.socket(zmq.SUB)
 | 
					        self.frame_sock = context.socket(zmq.SUB)
 | 
				
			||||||
| 
						 | 
					@ -54,7 +56,7 @@ class Tracker:
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def track(self):
 | 
					    def track(self):
 | 
				
			||||||
        while True:
 | 
					        while self.is_running.is_set():
 | 
				
			||||||
            msg = self.frame_sock.recv()
 | 
					            msg = self.frame_sock.recv()
 | 
				
			||||||
            frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
 | 
					            frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
 | 
				
			||||||
            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
 | 
					            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
 | 
				
			||||||
| 
						 | 
					@ -89,6 +91,7 @@ class Tracker:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            #TODO calculate fps (also for other loops to see asynchronicity)
 | 
					            #TODO calculate fps (also for other loops to see asynchronicity)
 | 
				
			||||||
            # fpsfilter=fpsfilter*.9+(1/dt)*.1    #trust value in order to stabilize fps display
 | 
					            # fpsfilter=fpsfilter*.9+(1/dt)*.1    #trust value in order to stabilize fps display
 | 
				
			||||||
 | 
					        logger.info('Stopping')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def detect_persons(self, frame) -> Detections:
 | 
					    def detect_persons(self, frame) -> Detections:
 | 
				
			||||||
| 
						 | 
					@ -128,6 +131,6 @@ class Tracker:
 | 
				
			||||||
        return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
 | 
					        return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def run_tracker(config: Namespace):
 | 
					def run_tracker(config: Namespace, is_running: Event):
 | 
				
			||||||
    router = Tracker(config)
 | 
					    router = Tracker(config, is_running)
 | 
				
			||||||
    router.track()
 | 
					    router.track()
 | 
				
			||||||
		Loading…
	
		Reference in a new issue