Compare commits


No commits in common. "821d06c9cf6126e3c1080c6702a68b5af1694388" and "23162da76796a21222cef5f4e70c7d9e296696c3" have entirely different histories.

7 changed files with 51 additions and 234 deletions

View file

@@ -1,8 +1,6 @@
 import argparse
 from pathlib import Path
-from pyparsing import Optional

 parser = argparse.ArgumentParser()
@@ -29,7 +27,6 @@ inference_parser = parser.add_argument_group('Inference')
 connection_parser = parser.add_argument_group('Connection')
 frame_emitter_parser = parser.add_argument_group('Frame emitter')
 tracker_parser = parser.add_argument_group('Tracker')
-render_parser = parser.add_argument_group('Renderer')

 inference_parser.add_argument("--model_dir",
                     help="directory with the model to use for inference",
@@ -109,7 +106,7 @@ inference_parser.add_argument("--eval_data_dict",
 inference_parser.add_argument("--output_dir",
                     help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
-                    type=Path,
+                    type=str,
                     default='./OUT/test_inference')
@@ -177,11 +174,3 @@ tracker_parser.add_argument("--homography",
                     help="File with homography params",
                     type=Path,
                     default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
-
-# Renderer
-# render_parser.add_argument("--output-dir",
-#                     help="Target image dir",
-#                     type=Optional[Path],
-#                     default=None)

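A note on the type=Path vs type=str change above: argparse calls whatever is passed as type= on the raw command-line string, so the two variants only differ in what args.output_dir ends up being. The dropped "from pyparsing import Optional" only served the commented-out type=Optional[Path] line, which would not have worked as a converter anyway. A minimal, illustrative sketch (option name and paths are made up):

    # Illustrative only: how argparse's type= callable determines the parsed value.
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser()
    inference = parser.add_argument_group('Inference')
    # argparse calls type() on the raw string, so type=Path yields a pathlib.Path,
    # while type=str keeps the plain string the user typed.
    inference.add_argument("--output_dir", type=Path, default=Path('./OUT/test_inference'))

    args = parser.parse_args(["--output_dir", "./OUT/run1"])
    print(type(args.output_dir))  # <class 'pathlib.PosixPath'> on Linux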
View file

@@ -1,36 +1,25 @@
 from argparse import Namespace
-from dataclasses import dataclass, field
 import logging
-from multiprocessing import Event
 import pickle
-import sys
 import time
-from typing import Optional
-import numpy as np
 import cv2
 import zmq

 logger = logging.getLogger('trap.frame_emitter')

-@dataclass
-class Frame:
-    img: np.array
-    time: float= field(default_factory=lambda: time.time())
-    trajectories: Optional[dict] = None

 class FrameEmitter:
     '''
     Emit frame in a separate threat so they can be throttled,
     or thrown away when the rest of the system cannot keep up
     '''
-    def __init__(self, config: Namespace, is_running: Event) -> None:
+    def __init__(self, config: Namespace) -> None:
         self.config = config
-        self.is_running = is_running

         context = zmq.Context()
         self.frame_sock = context.socket(zmq.PUB)
-        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
         self.frame_sock.bind(config.zmq_frame_addr)
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame

         logger.info(f"Connection socket {config.zmq_frame_addr}")

     def emit_video(self):
@@ -39,15 +28,15 @@ class FrameEmitter:
         frame_duration = 1./fps
         prev_time = time.time()

-        while self.is_running.is_set():
-            ret, img = video.read()
+        while True:
+            ret, frame = video.read()

             # seek to 0 if video has finished. Infinite loop
             if not ret:
                 video.set(cv2.CAP_PROP_POS_FRAMES, 0)
-                ret, img = video.read()
+                ret, frame = video.read()
                 assert ret is not False # not really error proof...

-            frame = Frame(img=img)
             # TODO: this is very dirty, need to find another way.
             # perhaps multiprocessing queue?
             self.frame_sock.send(pickle.dumps(frame))
@@ -63,6 +52,6 @@ class FrameEmitter:

-def run_frame_emitter(config: Namespace, is_running: Event):
-    router = FrameEmitter(config, is_running)
+def run_frame_emitter(config: Namespace):
+    router = FrameEmitter(config)
     router.emit_video()

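Two details from the frame-emitter hunk are worth spelling out: the removed Frame dataclass (its time default uses default_factory, so the timestamp is taken per instance, not at import), and ZMQ_CONFLATE, which the removed comment says must be set before bind/connect to take effect. A minimal sketch under those assumptions, with a placeholder address:

    # Sketch only: the Frame dataclass from the removed lines, plus a PUB socket
    # that keeps just the newest message. The address below is a placeholder.
    from dataclasses import dataclass, field
    from typing import Optional
    import pickle
    import time

    import numpy as np
    import zmq

    @dataclass
    class Frame:
        img: np.ndarray                                            # raw video frame
        time: float = field(default_factory=lambda: time.time())   # capture time, evaluated per instance
        trajectories: Optional[dict] = None                        # filled in later by the tracker

    context = zmq.Context()
    frame_sock = context.socket(zmq.PUB)
    frame_sock.setsockopt(zmq.CONFLATE, 1)   # keep only the latest message; set BEFORE bind/connect
    frame_sock.bind("tcp://127.0.0.1:5556")

    frame_sock.send(pickle.dumps(Frame(img=np.zeros((480, 640, 3), dtype=np.uint8))))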
View file

@@ -1,63 +1,38 @@
 import logging
 from logging.handlers import SocketHandler
-from multiprocessing import Event, Process, Queue
-import sys
+from multiprocessing import Process, Queue

 from trap.config import parser
 from trap.frame_emitter import run_frame_emitter
-from trap.prediction_server import run_prediction_server
-from trap.renderer import run_renderer
+from trap.prediction_server import InferenceServer, run_inference_server
 from trap.socket_forwarder import run_ws_forwarder
 from trap.tracker import run_tracker

 logger = logging.getLogger("trap.plumbing")

-class ExceptionHandlingProcess(Process):
-
-    def run(self):
-        assert 'is_running' in self._kwargs
-        try:
-            super(Process, self).run()
-        except Exception as e:
-            logger.exception(e)
-            self._kwargs['is_running'].clear()

 def start():
     args = parser.parse_args()
     loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
-    # print(args)
-    # exit()

     logging.basicConfig(
         level=loglevel,
     )
-    # set per handler, so we can set it lower for the root logger if remote logging is enabled
-    root_logger = logging.getLogger()
-    [h.setLevel(loglevel) for h in root_logger.handlers]
-
-    isRunning = Event()
-    isRunning.set()

     if args.remote_log_addr:
         logging.captureWarnings(True)
+        root_logger = logging.getLogger()
         root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
         socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
         root_logger.addHandler(socket_handler)

     # instantiating process with arguments
     procs = [
-        ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
-        ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
-        ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
-        ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
+        Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
+        Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
+        Process(target=run_tracker, args=(args,), name='tracker'),
     ]

     if not args.bypass_prediction:
         procs.append(
-            ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
+            Process(target=run_inference_server, args=(args,), name='inference'),
         )

     logger.info("start")

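The left-hand side couples the worker processes through a shared multiprocessing Event: every loop runs while is_running.is_set(), and ExceptionHandlingProcess clears the Event when any worker raises, so the whole pipeline winds down together. A stripped-down sketch of that pattern (the worker body is a stand-in):

    # Sketch of the cooperative-shutdown pattern from the left-hand version.
    import logging
    import time
    from multiprocessing import Event, Process

    logger = logging.getLogger("sketch.plumbing")

    class ExceptionHandlingProcess(Process):
        def run(self):
            assert 'is_running' in self._kwargs
            try:
                super(Process, self).run()
            except Exception as e:
                logger.exception(e)
                self._kwargs['is_running'].clear()   # tell every other worker to stop

    def worker(is_running):
        while is_running.is_set():       # exits as soon as any sibling clears the Event
            time.sleep(0.1)              # stand-in for the real work loop

    if __name__ == "__main__":
        is_running = Event()
        is_running.set()
        proc = ExceptionHandlingProcess(target=worker, kwargs={'is_running': is_running}, name='worker')
        proc.start()
        proc.join(timeout=1)
        is_running.clear()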
View file

@@ -1,14 +1,9 @@
 # adapted from Trajectron++ online_server.py
-from argparse import Namespace
 import logging
-from multiprocessing import Event, Queue
+from multiprocessing import Queue
 import os
-import pickle
-import sys
 import time
 import json
-import traceback
-import warnings
 import pandas as pd
 import torch
 import dill
@@ -26,8 +21,6 @@ import matplotlib.pyplot as plt
 import zmq

-from trap.frame_emitter import Frame

 logger = logging.getLogger("trap.prediction")
@@ -112,19 +105,15 @@ def get_maps_for_input(input_dict, scene, hyperparams):
     return maps_dict

-class PredictionServer:
-    def __init__(self, config: Namespace, is_running: Event):
+class InferenceServer:
+    def __init__(self, config: dict):
         self.config = config
-        self.is_running = is_running
-
-        if self.config.eval_device == 'cpu':
-            logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")

         context = zmq.Context()
         self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
-        self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
-        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
         self.trajectory_socket.connect(config.zmq_trajectory_addr)
+        self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
+        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg

         self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
         self.prediction_socket.bind(config.zmq_prediction_addr)
@@ -207,7 +196,7 @@ class PredictionServer:
         trajectron.set_environment(online_env, init_timestep)

         timestep = init_timestep + 1
-        while self.is_running.is_set():
+        while True:
             timestep += 1

             # for timestep in range(init_timestep + 1, eval_scene.timesteps):
@@ -231,10 +220,8 @@ class PredictionServer:
             # node_data = pd.DataFrame(data_dict, columns=data_columns)
             # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)

-            data = self.trajectory_socket.recv()
-            frame: Frame = pickle.loads(data)
-            trajectory_data = frame.trajectories # TODO: properly refractor
-            # trajectory_data = json.loads(data)
+            data = self.trajectory_socket.recv_string()
+            trajectory_data = json.loads(data)
             logger.debug(f"Receive {trajectory_data}")

             # class FakeNode:
@@ -304,12 +291,12 @@ class PredictionServer:
             start = time.time()
             dists, preds = trajectron.incremental_forward(input_dict,
                                                           maps,
-                                                          prediction_horizon=20, # TODO: make variable
-                                                          num_samples=2, # TODO: make variable
+                                                          prediction_horizon=16, # TODO: make variable
+                                                          num_samples=3, # TODO: make variable
                                                           robot_present_and_future=robot_present_and_future,
                                                           full_dist=True)
             end = time.time()
-            logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
+            logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
                                                                                1. / (end - start), len(trajectron.nodes),
                                                                                trajectron.scene_graph.get_num_edges()))
@@ -356,21 +343,10 @@ class PredictionServer:
             }

             data = json.dumps(response)
-            logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
             self.prediction_socket.send_string(data)
-        logger.info('Stopping')

-def run_prediction_server(config: Namespace, is_running: Event):
-
-    # attempt to trace the warnings coming from pytorch
-    # def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
-    #     log = file if hasattr(file,'write') else sys.stderr
-    #     traceback.print_stack(file=log)
-    #     log.write(warnings.formatwarning(message, category, filename, lineno, line))
-    # warnings.showwarning = warn_with_traceback
-
-    s = PredictionServer(config, is_running)
+def run_inference_server(config):
+    s = InferenceServer(config)
     s.run()

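The hunks above swap the transport on the trajectory SUB socket: the left-hand side receives whole pickled Frame objects and reads frame.trajectories, the right-hand side receives a plain JSON string. Both subscribe to everything and conflate to the newest message. A sketch of the two receive paths (address is a placeholder):

    # Sketch only; Frame refers to the dataclass defined in the frame emitter module.
    import json
    import pickle

    import zmq

    context = zmq.Context()
    trajectory_socket = context.socket(zmq.SUB)
    trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')   # empty prefix: receive every message
    trajectory_socket.setsockopt(zmq.CONFLATE, 1)      # keep only the most recent message
    trajectory_socket.connect("tcp://127.0.0.1:5557")

    # left-hand side: the full Frame travels as pickle, trajectories ride along with the image
    # frame = pickle.loads(trajectory_socket.recv())
    # trajectory_data = frame.trajectories

    # right-hand side: only the trajectory dict travels, serialized as JSON
    # trajectory_data = json.loads(trajectory_socket.recv_string())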
View file

@ -1,91 +0,0 @@
from argparse import Namespace
import logging
from multiprocessing import Event
import cv2
import numpy as np
import zmq
from trap.frame_emitter import Frame
logger = logging.getLogger("trap.renderer")
class Renderer:
def __init__(self, config: Namespace, is_running: Event):
self.config = config
self.is_running = is_running
context = zmq.Context()
self.prediction_sock = context.socket(zmq.SUB)
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_sock.connect(config.zmq_prediction_addr)
self.frame_sock = context.socket(zmq.SUB)
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
self.frame_sock.connect(config.zmq_frame_addr)
H = np.loadtxt(self.config.homography, delimiter=',')
self.inv_H = np.linalg.pinv(H)
if not self.config.output_dir.exists():
raise FileNotFoundError("Path does not exist")
def run(self):
predictions = {}
i=0
first_time = None
while self.is_running.is_set():
i+=1
frame: Frame = self.frame_sock.recv_pyobj()
try:
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
except zmq.ZMQError as e:
logger.debug(f'reuse prediction')
img = frame.img
for track_id, prediction in predictions.items():
if not 'history' in prediction or not len(prediction['history']):
continue
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
# logger.warning(f"{coords=}")
center = [int(p) for p in coords[-1]]
cv2.circle(img, center, 5, (0,255,0))
for ci in range(1, len(coords)):
start = [int(p) for p in coords[ci-1]]
end = [int(p) for p in coords[ci]]
cv2.line(img, start, end, (255,255,255), 2)
if not 'predictions' in prediction or not len(prediction['predictions']):
continue
for pred in prediction['predictions']:
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
for ci in range(1, len(pred_coords)):
start = [int(p) for p in pred_coords[ci-1]]
end = [int(p) for p in pred_coords[ci]]
cv2.line(img, start, end, (0,0,255), 2)
if first_time is None:
first_time = frame.time
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
cv2.imwrite(str(img_path), img)
logger.info('Stopping')
def run_renderer(config: Namespace, is_running: Event):
renderer = Renderer(config, is_running)
renderer.run()

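The deleted Renderer projects world-space trajectories back into image pixels with the inverse of the homography loaded from --homography, then draws history in white and predictions in red. A small sketch of that projection and drawing step, using a made-up 3x3 matrix rather than the VIRAT homography:

    # Sketch only: map world coordinates to pixels with the inverse homography and draw a segment.
    import cv2
    import numpy as np

    H = np.array([[1.0, 0.0, 10.0],      # placeholder homography, not the VIRAT one
                  [0.0, 1.0, 20.0],
                  [0.0, 0.0, 1.0]])
    inv_H = np.linalg.pinv(H)

    history = np.array([[[12.0, 24.0], [13.0, 25.5]]])   # shape (1, N, 2), as perspectiveTransform expects
    coords = cv2.perspectiveTransform(history, inv_H)[0]

    canvas = np.zeros((480, 640, 3), dtype=np.uint8)
    start = tuple(int(p) for p in coords[0])
    end = tuple(int(p) for p in coords[1])
    cv2.line(canvas, start, end, (255, 255, 255), 2)      # history segments are drawn white
    cv2.circle(canvas, end, 5, (0, 255, 0))               # last known position gets a green circle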
View file

@@ -2,7 +2,6 @@
 from argparse import Namespace
 import asyncio
 import logging
-from multiprocessing import Event
 from typing import Set, Union, Dict, Any
 from typing_extensions import Self
@@ -104,9 +103,8 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):

 class WsRouter:
-    def __init__(self, config: Namespace, is_running: Event):
+    def __init__(self, config: Namespace):
         self.config = config
-        self.is_running = is_running

         context = zmq.asyncio.Context()
         self.trajectory_socket = context.socket(zmq.PUB)
@@ -140,31 +138,25 @@ class WsRouter:
     def start(self):
-        self.evt_loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(self.evt_loop)
+        evt_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(evt_loop)

         # loop = tornado.ioloop.IOLoop.current()
         logger.info(f"Listen on {self.config.ws_port}")
         self.application.listen(self.config.ws_port)
         loop = asyncio.get_event_loop()

-        task = self.evt_loop.create_task(self.prediction_forwarder())
-        self.evt_loop.run_forever()
+        task = evt_loop.create_task(self.prediction_forwarder())
+        evt_loop.run_forever()

     async def prediction_forwarder(self):
         logger.info("Starting prediction forwarder")
-        while self.is_running.is_set():
+        while True:
             msg = await self.prediction_socket.recv_string()
             logger.debug(f"Forward prediction message of {len(msg)} chars")
             WebSocketPredictionHandler.write_to_clients(msg)

-        # die together:
-        self.evt_loop.stop()
-        logger.info('Stopping')
-
-def run_ws_forwarder(config: Namespace, is_running: Event):
-    router = WsRouter(config, is_running)
+def run_ws_forwarder(config: Namespace):
+    router = WsRouter(config)
     router.start()

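In the removed lines the event loop is kept on self so that, once the shared Event clears, prediction_forwarder can stop it and take the Tornado application down with it (the "die together" comment). A minimal sketch of that coupling, with a sleep standing in for the ZMQ await:

    # Sketch only: the coroutine and the loop stop together when is_running clears.
    import asyncio

    class RouterSketch:
        def __init__(self, is_running):
            self.is_running = is_running

        def start(self):
            self.evt_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.evt_loop)
            self.evt_loop.create_task(self.prediction_forwarder())
            self.evt_loop.run_forever()          # returns only after stop() below

        async def prediction_forwarder(self):
            while self.is_running.is_set():
                await asyncio.sleep(0.1)         # stand-in for: await self.prediction_socket.recv_string()
            self.evt_loop.stop()                 # ends run_forever(), so start() returns too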
View file

@@ -1,7 +1,6 @@
 from argparse import Namespace
 import json
 import logging
-from multiprocessing import Event
 import pickle
 import time
 import numpy as np
@@ -13,27 +12,24 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
 from deep_sort_realtime.deepsort_tracker import DeepSort
 from deep_sort_realtime.deep_sort.track import Track

-from trap.frame_emitter import Frame

 Detection = [int, int, int, int, float, int]
 Detections = [Detection]

 logger = logging.getLogger("trap.tracker")

 class Tracker:
-    def __init__(self, config: Namespace, is_running: Event):
+    def __init__(self, config: Namespace):
         self.config = config
-        self.is_running = is_running

         context = zmq.Context()
         self.frame_sock = context.socket(zmq.SUB)
-        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
-        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
         self.frame_sock.connect(config.zmq_frame_addr)
+        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
+        self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame

         self.trajectory_socket = context.socket(zmq.PUB)
-        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
         self.trajectory_socket.bind(config.zmq_trajectory_addr)
+        self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame

         # # TODO: config device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -56,13 +52,10 @@ class Tracker:

     def track(self):
-        while self.is_running.is_set():
-            msg = self.frame_sock.recv()
-            frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
-            # logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
-            start_time = time.time()
-            detections = self.detect_persons(frame.img)
-            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
+        while True:
+            frame = pickle.loads(self.frame_sock.recv())
+            detections = self.detect_persons(frame)
+            tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)

             TEMP_boxes = [t.to_ltwh() for t in tracks]
             TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
@@ -79,19 +72,13 @@ class Tracker:
                     "id": tid,
                     "history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
                 }
-            # logger.debug(f"{trajectories}")
-            frame.trajectories = trajectories
-
-            current_time = time.time()
-            logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
-
-            self.trajectory_socket.send(pickle.dumps(frame))
-            # self.trajectory_socket.send_string(json.dumps(trajectories))
+            logger.debug(f"{trajectories}")
+            self.trajectory_socket.send_string(json.dumps(trajectories))
             # provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
             # TODO: provide a track object that actually keeps history (unlike tracker)

             #TODO calculate fps (also for other loops to see asynchonity)
             # fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
-        logger.info('Stopping')

     def detect_persons(self, frame) -> Detections:
@@ -131,6 +118,6 @@ class Tracker:
         return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]

-def run_tracker(config: Namespace, is_running: Event):
-    router = Tracker(config, is_running)
+def run_tracker(config: Namespace):
+    router = Tracker(config)
     router.track()
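For reference, the detection tuples handed to DeepSort at the end of this file: detect_persons converts retinanet-style [x1, y1, x2, y2, score, class] rows into ([left, top, width, height], confidence, class) tuples, and the track loop then takes the bottom-centre of each returned box as the foot point for the homography step. A sketch with made-up values:

    # Sketch only: the box-format conversion used before mot_tracker.update_tracks(...).
    xyxy_detections = [
        # x1,  y1,  x2,  y2, score, class
        [100, 120, 140, 220, 0.93, 1],
        [300,  80, 330, 180, 0.81, 1],
    ]

    ltwh_detections = [
        ([d[0], d[1], d[2] - d[0], d[3] - d[1]], d[4], d[5])
        for d in xyxy_detections
    ]

    # each box's foot point (bottom-centre), as used before projecting to world coordinates
    foot_points = [[box[0][0] + 0.5 * box[0][2], box[0][1] + box[0][3]] for box in ltwh_detections]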