Compare commits

...

4 Commits

Author SHA1 Message Date
Ruben van de Ven 821d06c9cf Test renderer to frames 2023-10-17 09:49:20 +02:00
Ruben van de Ven 3d34263a71 Gracefully stop all loops on crash of subprocess 2023-10-17 09:45:37 +02:00
Ruben van de Ven 9b39d7cd9b Logging tweaks 2023-10-14 13:47:07 +02:00
Ruben van de Ven 27565d919e Fix zmq-conflate and measure timings 2023-10-13 23:13:12 +02:00
7 changed files with 234 additions and 51 deletions

View File

@ -1,6 +1,8 @@
import argparse
from pathlib import Path
from pyparsing import Optional
parser = argparse.ArgumentParser()
@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
connection_parser = parser.add_argument_group('Connection')
frame_emitter_parser = parser.add_argument_group('Frame emitter')
tracker_parser = parser.add_argument_group('Tracker')
render_parser = parser.add_argument_group('Renderer')
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=str,
type=Path,
default='./OUT/test_inference')
@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
help="File with homography params",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
# Renderer
# render_parser.add_argument("--output-dir",
# help="Target image dir",
# type=Optional[Path],
# default=None)

View File

@ -1,25 +1,36 @@
from argparse import Namespace
from dataclasses import dataclass, field
import logging
from multiprocessing import Event
import pickle
import sys
import time
from typing import Optional
import numpy as np
import cv2
import zmq
logger = logging.getLogger('trap.frame_emitter')
@dataclass
class Frame:
img: np.array
time: float= field(default_factory=lambda: time.time())
trajectories: Optional[dict] = None
class FrameEmitter:
'''
Emit frame in a separate threat so they can be throttled,
or thrown away when the rest of the system cannot keep up
'''
def __init__(self, config: Namespace) -> None:
def __init__(self, config: Namespace, is_running: Event) -> None:
self.config = config
self.is_running = is_running
context = zmq.Context()
self.frame_sock = context.socket(zmq.PUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
self.frame_sock.bind(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
logger.info(f"Connection socket {config.zmq_frame_addr}")
def emit_video(self):
@ -28,15 +39,15 @@ class FrameEmitter:
frame_duration = 1./fps
prev_time = time.time()
while True:
ret, frame = video.read()
while self.is_running.is_set():
ret, img = video.read()
# seek to 0 if video has finished. Infinite loop
if not ret:
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = video.read()
ret, img = video.read()
assert ret is not False # not really error proof...
frame = Frame(img=img)
# TODO: this is very dirty, need to find another way.
# perhaps multiprocessing queue?
self.frame_sock.send(pickle.dumps(frame))
@ -50,8 +61,8 @@ class FrameEmitter:
else:
prev_time = new_frame_time
def run_frame_emitter(config: Namespace):
router = FrameEmitter(config)
def run_frame_emitter(config: Namespace, is_running: Event):
router = FrameEmitter(config, is_running)
router.emit_video()

View File

@ -1,38 +1,63 @@
import logging
from logging.handlers import SocketHandler
from multiprocessing import Process, Queue
from multiprocessing import Event, Process, Queue
import sys
from trap.config import parser
from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import InferenceServer, run_inference_server
from trap.prediction_server import run_prediction_server
from trap.renderer import run_renderer
from trap.socket_forwarder import run_ws_forwarder
from trap.tracker import run_tracker
logger = logging.getLogger("trap.plumbing")
class ExceptionHandlingProcess(Process):
def run(self):
assert 'is_running' in self._kwargs
try:
super(Process, self).run()
except Exception as e:
logger.exception(e)
self._kwargs['is_running'].clear()
def start():
args = parser.parse_args()
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
# print(args)
# exit()
logging.basicConfig(
level=loglevel,
)
# set per handler, so we can set it lower for the root logger if remote logging is enabled
root_logger = logging.getLogger()
[h.setLevel(loglevel) for h in root_logger.handlers]
isRunning = Event()
isRunning.set()
if args.remote_log_addr:
logging.captureWarnings(True)
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
root_logger.addHandler(socket_handler)
# instantiating process with arguments
procs = [
Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
Process(target=run_tracker, args=(args,), name='tracker'),
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
]
if not args.bypass_prediction:
procs.append(
Process(target=run_inference_server, args=(args,), name='inference'),
ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
)
logger.info("start")

View File

@ -1,9 +1,14 @@
# adapted from Trajectron++ online_server.py
from argparse import Namespace
import logging
from multiprocessing import Queue
from multiprocessing import Event, Queue
import os
import pickle
import sys
import time
import json
import traceback
import warnings
import pandas as pd
import torch
import dill
@ -21,6 +26,8 @@ import matplotlib.pyplot as plt
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.prediction")
@ -105,15 +112,19 @@ def get_maps_for_input(input_dict, scene, hyperparams):
return maps_dict
class InferenceServer:
def __init__(self, config: dict):
class PredictionServer:
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
if self.config.eval_device == 'cpu':
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
@ -196,7 +207,7 @@ class InferenceServer:
trajectron.set_environment(online_env, init_timestep)
timestep = init_timestep + 1
while True:
while self.is_running.is_set():
timestep += 1
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
@ -220,8 +231,10 @@ class InferenceServer:
# node_data = pd.DataFrame(data_dict, columns=data_columns)
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
data = self.trajectory_socket.recv_string()
trajectory_data = json.loads(data)
data = self.trajectory_socket.recv()
frame: Frame = pickle.loads(data)
trajectory_data = frame.trajectories # TODO: properly refractor
# trajectory_data = json.loads(data)
logger.debug(f"Receive {trajectory_data}")
# class FakeNode:
@ -291,12 +304,12 @@ class InferenceServer:
start = time.time()
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=16, # TODO: make variable
num_samples=3, # TODO: make variable
prediction_horizon=20, # TODO: make variable
num_samples=2, # TODO: make variable
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
1. / (end - start), len(trajectron.nodes),
trajectron.scene_graph.get_num_edges()))
@ -343,10 +356,21 @@ class InferenceServer:
}
data = json.dumps(response)
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
self.prediction_socket.send_string(data)
logger.info('Stopping')
def run_inference_server(config):
s = InferenceServer(config)
def run_prediction_server(config: Namespace, is_running: Event):
# attempt to trace the warnings coming from pytorch
# def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
# log = file if hasattr(file,'write') else sys.stderr
# traceback.print_stack(file=log)
# log.write(warnings.formatwarning(message, category, filename, lineno, line))
# warnings.showwarning = warn_with_traceback
s = PredictionServer(config, is_running)
s.run()

91
trap/renderer.py Normal file
View File

@ -0,0 +1,91 @@
from argparse import Namespace
import logging
from multiprocessing import Event
import cv2
import numpy as np
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer")
class Renderer:
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
context = zmq.Context()
self.prediction_sock = context.socket(zmq.SUB)
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_sock.connect(config.zmq_prediction_addr)
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(config.zmq_frame_addr)
H = np.loadtxt(self.config.homography, delimiter=',')
self.inv_H = np.linalg.pinv(H)
if not self.config.output_dir.exists():
raise FileNotFoundError("Path does not exist")
def run(self):
predictions = {}
i=0
first_time = None
while self.is_running.is_set():
i+=1
frame: Frame = self.frame_sock.recv_pyobj()
try:
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
except zmq.ZMQError as e:
logger.debug(f'reuse prediction')
img = frame.img
for track_id, prediction in predictions.items():
if not 'history' in prediction or not len(prediction['history']):
continue
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
# logger.warning(f"{coords=}")
center = [int(p) for p in coords[-1]]
cv2.circle(img, center, 5, (0,255,0))
for ci in range(1, len(coords)):
start = [int(p) for p in coords[ci-1]]
end = [int(p) for p in coords[ci]]
cv2.line(img, start, end, (255,255,255), 2)
if not 'predictions' in prediction or not len(prediction['predictions']):
continue
for pred in prediction['predictions']:
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
for ci in range(1, len(pred_coords)):
start = [int(p) for p in pred_coords[ci-1]]
end = [int(p) for p in pred_coords[ci]]
cv2.line(img, start, end, (0,0,255), 2)
if first_time is None:
first_time = frame.time
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
cv2.imwrite(str(img_path), img)
logger.info('Stopping')
def run_renderer(config: Namespace, is_running: Event):
renderer = Renderer(config, is_running)
renderer.run()

View File

@ -2,6 +2,7 @@
from argparse import Namespace
import asyncio
import logging
from multiprocessing import Event
from typing import Set, Union, Dict, Any
from typing_extensions import Self
@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
class WsRouter:
def __init__(self, config: Namespace):
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB)
@ -138,25 +140,31 @@ class WsRouter:
def start(self):
evt_loop = asyncio.new_event_loop()
asyncio.set_event_loop(evt_loop)
self.evt_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.evt_loop)
# loop = tornado.ioloop.IOLoop.current()
logger.info(f"Listen on {self.config.ws_port}")
self.application.listen(self.config.ws_port)
loop = asyncio.get_event_loop()
task = evt_loop.create_task(self.prediction_forwarder())
task = self.evt_loop.create_task(self.prediction_forwarder())
evt_loop.run_forever()
self.evt_loop.run_forever()
async def prediction_forwarder(self):
logger.info("Starting prediction forwarder")
while True:
while self.is_running.is_set():
msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
# die together:
self.evt_loop.stop()
logger.info('Stopping')
def run_ws_forwarder(config: Namespace):
router = WsRouter(config)
def run_ws_forwarder(config: Namespace, is_running: Event):
router = WsRouter(config, is_running)
router.start()

View File

@ -1,6 +1,7 @@
from argparse import Namespace
import json
import logging
from multiprocessing import Event
import pickle
import time
import numpy as np
@ -12,24 +13,27 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
from deep_sort_realtime.deepsort_tracker import DeepSort
from deep_sort_realtime.deep_sort.track import Track
from trap.frame_emitter import Frame
Detection = [int, int, int, int, float, int]
Detections = [Detection]
logger = logging.getLogger("trap.tracker")
class Tracker:
def __init__(self, config: Namespace):
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
context = zmq.Context()
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.connect(config.zmq_frame_addr)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.frame_sock.connect(config.zmq_frame_addr)
self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
self.trajectory_socket.bind(config.zmq_trajectory_addr)
# # TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -52,10 +56,13 @@ class Tracker:
def track(self):
while True:
frame = pickle.loads(self.frame_sock.recv())
detections = self.detect_persons(frame)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
while self.is_running.is_set():
msg = self.frame_sock.recv()
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
start_time = time.time()
detections = self.detect_persons(frame.img)
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
TEMP_boxes = [t.to_ltwh() for t in tracks]
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
@ -72,13 +79,19 @@ class Tracker:
"id": tid,
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
}
logger.debug(f"{trajectories}")
self.trajectory_socket.send_string(json.dumps(trajectories))
# logger.debug(f"{trajectories}")
frame.trajectories = trajectories
current_time = time.time()
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
self.trajectory_socket.send(pickle.dumps(frame))
# self.trajectory_socket.send_string(json.dumps(trajectories))
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
# TODO: provide a track object that actually keeps history (unlike tracker)
#TODO calculate fps (also for other loops to see asynchonity)
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
logger.info('Stopping')
def detect_persons(self, frame) -> Detections:
@ -118,6 +131,6 @@ class Tracker:
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
def run_tracker(config: Namespace):
router = Tracker(config)
def run_tracker(config: Namespace, is_running: Event):
router = Tracker(config, is_running)
router.track()