Compare commits
4 commits
23162da767
...
821d06c9cf
Author | SHA1 | Date | |
---|---|---|---|
|
821d06c9cf | ||
|
3d34263a71 | ||
|
9b39d7cd9b | ||
|
27565d919e |
7 changed files with 234 additions and 51 deletions
|
@ -1,6 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pyparsing import Optional
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +29,7 @@ inference_parser = parser.add_argument_group('Inference')
|
||||||
connection_parser = parser.add_argument_group('Connection')
|
connection_parser = parser.add_argument_group('Connection')
|
||||||
frame_emitter_parser = parser.add_argument_group('Frame emitter')
|
frame_emitter_parser = parser.add_argument_group('Frame emitter')
|
||||||
tracker_parser = parser.add_argument_group('Tracker')
|
tracker_parser = parser.add_argument_group('Tracker')
|
||||||
|
render_parser = parser.add_argument_group('Renderer')
|
||||||
|
|
||||||
inference_parser.add_argument("--model_dir",
|
inference_parser.add_argument("--model_dir",
|
||||||
help="directory with the model to use for inference",
|
help="directory with the model to use for inference",
|
||||||
|
@ -106,7 +109,7 @@ inference_parser.add_argument("--eval_data_dict",
|
||||||
|
|
||||||
inference_parser.add_argument("--output_dir",
|
inference_parser.add_argument("--output_dir",
|
||||||
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
||||||
type=str,
|
type=Path,
|
||||||
default='./OUT/test_inference')
|
default='./OUT/test_inference')
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,3 +177,11 @@ tracker_parser.add_argument("--homography",
|
||||||
help="File with homography params",
|
help="File with homography params",
|
||||||
type=Path,
|
type=Path,
|
||||||
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
|
default='../DATASETS/VIRAT_subset_0102x/VIRAT_0102_homography_img2world.txt')
|
||||||
|
|
||||||
|
# Renderer
|
||||||
|
|
||||||
|
# render_parser.add_argument("--output-dir",
|
||||||
|
# help="Target image dir",
|
||||||
|
# type=Optional[Path],
|
||||||
|
# default=None)
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,36 @@
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from dataclasses import dataclass, field
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import Event
|
||||||
import pickle
|
import pickle
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
logger = logging.getLogger('trap.frame_emitter')
|
logger = logging.getLogger('trap.frame_emitter')
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Frame:
|
||||||
|
img: np.array
|
||||||
|
time: float= field(default_factory=lambda: time.time())
|
||||||
|
trajectories: Optional[dict] = None
|
||||||
|
|
||||||
class FrameEmitter:
|
class FrameEmitter:
|
||||||
'''
|
'''
|
||||||
Emit frame in a separate threat so they can be throttled,
|
Emit frame in a separate threat so they can be throttled,
|
||||||
or thrown away when the rest of the system cannot keep up
|
or thrown away when the rest of the system cannot keep up
|
||||||
'''
|
'''
|
||||||
def __init__(self, config: Namespace) -> None:
|
def __init__(self, config: Namespace, is_running: Event) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
self.frame_sock = context.socket(zmq.PUB)
|
self.frame_sock = context.socket(zmq.PUB)
|
||||||
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. make sure to set BEFORE connect/bind
|
||||||
self.frame_sock.bind(config.zmq_frame_addr)
|
self.frame_sock.bind(config.zmq_frame_addr)
|
||||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
|
||||||
logger.info(f"Connection socket {config.zmq_frame_addr}")
|
logger.info(f"Connection socket {config.zmq_frame_addr}")
|
||||||
|
|
||||||
def emit_video(self):
|
def emit_video(self):
|
||||||
|
@ -28,15 +39,15 @@ class FrameEmitter:
|
||||||
frame_duration = 1./fps
|
frame_duration = 1./fps
|
||||||
|
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
while True:
|
while self.is_running.is_set():
|
||||||
ret, frame = video.read()
|
ret, img = video.read()
|
||||||
|
|
||||||
# seek to 0 if video has finished. Infinite loop
|
# seek to 0 if video has finished. Infinite loop
|
||||||
if not ret:
|
if not ret:
|
||||||
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||||
ret, frame = video.read()
|
ret, img = video.read()
|
||||||
assert ret is not False # not really error proof...
|
assert ret is not False # not really error proof...
|
||||||
|
frame = Frame(img=img)
|
||||||
# TODO: this is very dirty, need to find another way.
|
# TODO: this is very dirty, need to find another way.
|
||||||
# perhaps multiprocessing queue?
|
# perhaps multiprocessing queue?
|
||||||
self.frame_sock.send(pickle.dumps(frame))
|
self.frame_sock.send(pickle.dumps(frame))
|
||||||
|
@ -52,6 +63,6 @@ class FrameEmitter:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_frame_emitter(config: Namespace):
|
def run_frame_emitter(config: Namespace, is_running: Event):
|
||||||
router = FrameEmitter(config)
|
router = FrameEmitter(config, is_running)
|
||||||
router.emit_video()
|
router.emit_video()
|
|
@ -1,38 +1,63 @@
|
||||||
import logging
|
import logging
|
||||||
from logging.handlers import SocketHandler
|
from logging.handlers import SocketHandler
|
||||||
from multiprocessing import Process, Queue
|
from multiprocessing import Event, Process, Queue
|
||||||
|
import sys
|
||||||
from trap.config import parser
|
from trap.config import parser
|
||||||
from trap.frame_emitter import run_frame_emitter
|
from trap.frame_emitter import run_frame_emitter
|
||||||
from trap.prediction_server import InferenceServer, run_inference_server
|
from trap.prediction_server import run_prediction_server
|
||||||
|
from trap.renderer import run_renderer
|
||||||
from trap.socket_forwarder import run_ws_forwarder
|
from trap.socket_forwarder import run_ws_forwarder
|
||||||
from trap.tracker import run_tracker
|
from trap.tracker import run_tracker
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("trap.plumbing")
|
logger = logging.getLogger("trap.plumbing")
|
||||||
|
|
||||||
|
class ExceptionHandlingProcess(Process):
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
assert 'is_running' in self._kwargs
|
||||||
|
try:
|
||||||
|
super(Process, self).run()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(e)
|
||||||
|
self._kwargs['is_running'].clear()
|
||||||
|
|
||||||
|
|
||||||
def start():
|
def start():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
|
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
|
||||||
|
# print(args)
|
||||||
|
# exit()
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=loglevel,
|
level=loglevel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# set per handler, so we can set it lower for the root logger if remote logging is enabled
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
[h.setLevel(loglevel) for h in root_logger.handlers]
|
||||||
|
|
||||||
|
isRunning = Event()
|
||||||
|
isRunning.set()
|
||||||
|
|
||||||
|
|
||||||
if args.remote_log_addr:
|
if args.remote_log_addr:
|
||||||
logging.captureWarnings(True)
|
logging.captureWarnings(True)
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
|
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
|
||||||
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
|
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
|
||||||
root_logger.addHandler(socket_handler)
|
root_logger.addHandler(socket_handler)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# instantiating process with arguments
|
# instantiating process with arguments
|
||||||
procs = [
|
procs = [
|
||||||
Process(target=run_ws_forwarder, args=(args,), name='forwarder'),
|
ExceptionHandlingProcess(target=run_ws_forwarder, kwargs={'config': args, 'is_running': isRunning}, name='forwarder'),
|
||||||
Process(target=run_frame_emitter, args=(args,), name='frame_emitter'),
|
ExceptionHandlingProcess(target=run_frame_emitter, kwargs={'config': args, 'is_running': isRunning}, name='frame_emitter'),
|
||||||
Process(target=run_tracker, args=(args,), name='tracker'),
|
ExceptionHandlingProcess(target=run_tracker, kwargs={'config': args, 'is_running': isRunning}, name='tracker'),
|
||||||
|
ExceptionHandlingProcess(target=run_renderer, kwargs={'config': args, 'is_running': isRunning}, name='renderer'),
|
||||||
]
|
]
|
||||||
if not args.bypass_prediction:
|
if not args.bypass_prediction:
|
||||||
procs.append(
|
procs.append(
|
||||||
Process(target=run_inference_server, args=(args,), name='inference'),
|
ExceptionHandlingProcess(target=run_prediction_server, kwargs={'config': args, 'is_running':isRunning}, name='inference'),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("start")
|
logger.info("start")
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
# adapted from Trajectron++ online_server.py
|
# adapted from Trajectron++ online_server.py
|
||||||
|
from argparse import Namespace
|
||||||
import logging
|
import logging
|
||||||
from multiprocessing import Queue
|
from multiprocessing import Event, Queue
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
|
import traceback
|
||||||
|
import warnings
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
import dill
|
import dill
|
||||||
|
@ -21,6 +26,8 @@ import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
|
from trap.frame_emitter import Frame
|
||||||
|
|
||||||
logger = logging.getLogger("trap.prediction")
|
logger = logging.getLogger("trap.prediction")
|
||||||
|
|
||||||
|
|
||||||
|
@ -105,15 +112,19 @@ def get_maps_for_input(input_dict, scene, hyperparams):
|
||||||
return maps_dict
|
return maps_dict
|
||||||
|
|
||||||
|
|
||||||
class InferenceServer:
|
class PredictionServer:
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
|
if self.config.eval_device == 'cpu':
|
||||||
|
logger.warning("Running on CPU. Specifying --eval_device cuda:0 should dramatically speed up prediction")
|
||||||
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
||||||
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
|
||||||
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
|
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg. Set BEFORE connect!
|
||||||
|
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
||||||
|
|
||||||
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
||||||
self.prediction_socket.bind(config.zmq_prediction_addr)
|
self.prediction_socket.bind(config.zmq_prediction_addr)
|
||||||
|
@ -196,7 +207,7 @@ class InferenceServer:
|
||||||
trajectron.set_environment(online_env, init_timestep)
|
trajectron.set_environment(online_env, init_timestep)
|
||||||
|
|
||||||
timestep = init_timestep + 1
|
timestep = init_timestep + 1
|
||||||
while True:
|
while self.is_running.is_set():
|
||||||
timestep += 1
|
timestep += 1
|
||||||
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
||||||
|
|
||||||
|
@ -220,8 +231,10 @@ class InferenceServer:
|
||||||
# node_data = pd.DataFrame(data_dict, columns=data_columns)
|
# node_data = pd.DataFrame(data_dict, columns=data_columns)
|
||||||
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
|
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
|
||||||
|
|
||||||
data = self.trajectory_socket.recv_string()
|
data = self.trajectory_socket.recv()
|
||||||
trajectory_data = json.loads(data)
|
frame: Frame = pickle.loads(data)
|
||||||
|
trajectory_data = frame.trajectories # TODO: properly refractor
|
||||||
|
# trajectory_data = json.loads(data)
|
||||||
logger.debug(f"Receive {trajectory_data}")
|
logger.debug(f"Receive {trajectory_data}")
|
||||||
|
|
||||||
# class FakeNode:
|
# class FakeNode:
|
||||||
|
@ -291,12 +304,12 @@ class InferenceServer:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
dists, preds = trajectron.incremental_forward(input_dict,
|
dists, preds = trajectron.incremental_forward(input_dict,
|
||||||
maps,
|
maps,
|
||||||
prediction_horizon=16, # TODO: make variable
|
prediction_horizon=20, # TODO: make variable
|
||||||
num_samples=3, # TODO: make variable
|
num_samples=2, # TODO: make variable
|
||||||
robot_present_and_future=robot_present_and_future,
|
robot_present_and_future=robot_present_and_future,
|
||||||
full_dist=True)
|
full_dist=True)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
logger.debug("took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (end - start,
|
||||||
1. / (end - start), len(trajectron.nodes),
|
1. / (end - start), len(trajectron.nodes),
|
||||||
trajectron.scene_graph.get_num_edges()))
|
trajectron.scene_graph.get_num_edges()))
|
||||||
|
|
||||||
|
@ -343,10 +356,21 @@ class InferenceServer:
|
||||||
}
|
}
|
||||||
|
|
||||||
data = json.dumps(response)
|
data = json.dumps(response)
|
||||||
|
logger.info(f"Total frame delay = {time.time()-frame.time}s ({len(trajectron.nodes)} nodes & {trajectron.scene_graph.get_num_edges()} edges. Trajectron: {end - start}s)")
|
||||||
self.prediction_socket.send_string(data)
|
self.prediction_socket.send_string(data)
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_inference_server(config):
|
def run_prediction_server(config: Namespace, is_running: Event):
|
||||||
s = InferenceServer(config)
|
|
||||||
|
# attempt to trace the warnings coming from pytorch
|
||||||
|
# def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||||
|
|
||||||
|
# log = file if hasattr(file,'write') else sys.stderr
|
||||||
|
# traceback.print_stack(file=log)
|
||||||
|
# log.write(warnings.formatwarning(message, category, filename, lineno, line))
|
||||||
|
|
||||||
|
# warnings.showwarning = warn_with_traceback
|
||||||
|
s = PredictionServer(config, is_running)
|
||||||
s.run()
|
s.run()
|
91
trap/renderer.py
Normal file
91
trap/renderer.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
import logging
|
||||||
|
from multiprocessing import Event
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from trap.frame_emitter import Frame
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("trap.renderer")
|
||||||
|
|
||||||
|
class Renderer:
|
||||||
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
|
context = zmq.Context()
|
||||||
|
self.prediction_sock = context.socket(zmq.SUB)
|
||||||
|
self.prediction_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.prediction_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.prediction_sock.connect(config.zmq_prediction_addr)
|
||||||
|
|
||||||
|
self.frame_sock = context.socket(zmq.SUB)
|
||||||
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
|
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
self.frame_sock.connect(config.zmq_frame_addr)
|
||||||
|
|
||||||
|
|
||||||
|
H = np.loadtxt(self.config.homography, delimiter=',')
|
||||||
|
|
||||||
|
self.inv_H = np.linalg.pinv(H)
|
||||||
|
|
||||||
|
if not self.config.output_dir.exists():
|
||||||
|
raise FileNotFoundError("Path does not exist")
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
predictions = {}
|
||||||
|
i=0
|
||||||
|
first_time = None
|
||||||
|
while self.is_running.is_set():
|
||||||
|
i+=1
|
||||||
|
frame: Frame = self.frame_sock.recv_pyobj()
|
||||||
|
try:
|
||||||
|
predictions = self.prediction_sock.recv_json(zmq.NOBLOCK)
|
||||||
|
except zmq.ZMQError as e:
|
||||||
|
logger.debug(f'reuse prediction')
|
||||||
|
|
||||||
|
img = frame.img
|
||||||
|
for track_id, prediction in predictions.items():
|
||||||
|
if not 'history' in prediction or not len(prediction['history']):
|
||||||
|
continue
|
||||||
|
coords = cv2.perspectiveTransform(np.array([prediction['history']]), self.inv_H)[0]
|
||||||
|
# logger.warning(f"{coords=}")
|
||||||
|
center = [int(p) for p in coords[-1]]
|
||||||
|
cv2.circle(img, center, 5, (0,255,0))
|
||||||
|
|
||||||
|
for ci in range(1, len(coords)):
|
||||||
|
start = [int(p) for p in coords[ci-1]]
|
||||||
|
end = [int(p) for p in coords[ci]]
|
||||||
|
cv2.line(img, start, end, (255,255,255), 2)
|
||||||
|
|
||||||
|
if not 'predictions' in prediction or not len(prediction['predictions']):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for pred in prediction['predictions']:
|
||||||
|
pred_coords = cv2.perspectiveTransform(np.array([pred]), self.inv_H)[0]
|
||||||
|
for ci in range(1, len(pred_coords)):
|
||||||
|
start = [int(p) for p in pred_coords[ci-1]]
|
||||||
|
end = [int(p) for p in pred_coords[ci]]
|
||||||
|
cv2.line(img, start, end, (0,0,255), 2)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if first_time is None:
|
||||||
|
first_time = frame.time
|
||||||
|
|
||||||
|
cv2.putText(img, f"{frame.time - first_time:.3f}s", (20,50), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,0), 1)
|
||||||
|
|
||||||
|
img_path = (self.config.output_dir / f"{i:05d}.png").resolve()
|
||||||
|
|
||||||
|
cv2.imwrite(str(img_path), img)
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_renderer(config: Namespace, is_running: Event):
|
||||||
|
renderer = Renderer(config, is_running)
|
||||||
|
renderer.run()
|
|
@ -2,6 +2,7 @@
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import Event
|
||||||
from typing import Set, Union, Dict, Any
|
from typing import Set, Union, Dict, Any
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
@ -103,8 +104,9 @@ class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
|
||||||
|
|
||||||
|
|
||||||
class WsRouter:
|
class WsRouter:
|
||||||
def __init__(self, config: Namespace):
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
context = zmq.asyncio.Context()
|
context = zmq.asyncio.Context()
|
||||||
self.trajectory_socket = context.socket(zmq.PUB)
|
self.trajectory_socket = context.socket(zmq.PUB)
|
||||||
|
@ -138,25 +140,31 @@ class WsRouter:
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
|
|
||||||
evt_loop = asyncio.new_event_loop()
|
self.evt_loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(evt_loop)
|
asyncio.set_event_loop(self.evt_loop)
|
||||||
|
|
||||||
# loop = tornado.ioloop.IOLoop.current()
|
# loop = tornado.ioloop.IOLoop.current()
|
||||||
logger.info(f"Listen on {self.config.ws_port}")
|
logger.info(f"Listen on {self.config.ws_port}")
|
||||||
self.application.listen(self.config.ws_port)
|
self.application.listen(self.config.ws_port)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
task = evt_loop.create_task(self.prediction_forwarder())
|
task = self.evt_loop.create_task(self.prediction_forwarder())
|
||||||
|
|
||||||
|
self.evt_loop.run_forever()
|
||||||
|
|
||||||
evt_loop.run_forever()
|
|
||||||
|
|
||||||
async def prediction_forwarder(self):
|
async def prediction_forwarder(self):
|
||||||
logger.info("Starting prediction forwarder")
|
logger.info("Starting prediction forwarder")
|
||||||
while True:
|
while self.is_running.is_set():
|
||||||
msg = await self.prediction_socket.recv_string()
|
msg = await self.prediction_socket.recv_string()
|
||||||
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
||||||
WebSocketPredictionHandler.write_to_clients(msg)
|
WebSocketPredictionHandler.write_to_clients(msg)
|
||||||
|
|
||||||
def run_ws_forwarder(config: Namespace):
|
# die together:
|
||||||
router = WsRouter(config)
|
self.evt_loop.stop()
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
|
def run_ws_forwarder(config: Namespace, is_running: Event):
|
||||||
|
router = WsRouter(config, is_running)
|
||||||
router.start()
|
router.start()
|
|
@ -1,6 +1,7 @@
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from multiprocessing import Event
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -12,24 +13,27 @@ from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_Re
|
||||||
from deep_sort_realtime.deepsort_tracker import DeepSort
|
from deep_sort_realtime.deepsort_tracker import DeepSort
|
||||||
from deep_sort_realtime.deep_sort.track import Track
|
from deep_sort_realtime.deep_sort.track import Track
|
||||||
|
|
||||||
|
from trap.frame_emitter import Frame
|
||||||
|
|
||||||
Detection = [int, int, int, int, float, int]
|
Detection = [int, int, int, int, float, int]
|
||||||
Detections = [Detection]
|
Detections = [Detection]
|
||||||
|
|
||||||
logger = logging.getLogger("trap.tracker")
|
logger = logging.getLogger("trap.tracker")
|
||||||
|
|
||||||
class Tracker:
|
class Tracker:
|
||||||
def __init__(self, config: Namespace):
|
def __init__(self, config: Namespace, is_running: Event):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.is_running = is_running
|
||||||
|
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
self.frame_sock = context.socket(zmq.SUB)
|
self.frame_sock = context.socket(zmq.SUB)
|
||||||
self.frame_sock.connect(config.zmq_frame_addr)
|
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame. NB. make sure this comes BEFORE connect, otherwise it's ignored!!
|
||||||
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
self.frame_sock.connect(config.zmq_frame_addr)
|
||||||
|
|
||||||
self.trajectory_socket = context.socket(zmq.PUB)
|
self.trajectory_socket = context.socket(zmq.PUB)
|
||||||
self.trajectory_socket.bind(config.zmq_trajectory_addr)
|
|
||||||
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep latest frame
|
||||||
|
self.trajectory_socket.bind(config.zmq_trajectory_addr)
|
||||||
|
|
||||||
# # TODO: config device
|
# # TODO: config device
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
@ -52,10 +56,13 @@ class Tracker:
|
||||||
|
|
||||||
|
|
||||||
def track(self):
|
def track(self):
|
||||||
while True:
|
while self.is_running.is_set():
|
||||||
frame = pickle.loads(self.frame_sock.recv())
|
msg = self.frame_sock.recv()
|
||||||
detections = self.detect_persons(frame)
|
frame: Frame = pickle.loads(msg) # frame delivery in current setup: 0.012-0.03s
|
||||||
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame)
|
# logger.info(f"Frame delivery delay = {time.time()-frame.time}s")
|
||||||
|
start_time = time.time()
|
||||||
|
detections = self.detect_persons(frame.img)
|
||||||
|
tracks: [Track] = self.mot_tracker.update_tracks(detections, frame=frame.img)
|
||||||
|
|
||||||
TEMP_boxes = [t.to_ltwh() for t in tracks]
|
TEMP_boxes = [t.to_ltwh() for t in tracks]
|
||||||
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
|
TEMP_coords = np.array([[[det[0] + 0.5 * det[2], det[1]+det[3]]] for det in TEMP_boxes])
|
||||||
|
@ -72,13 +79,19 @@ class Tracker:
|
||||||
"id": tid,
|
"id": tid,
|
||||||
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
|
"history": [{"x":c[0], "y":c[1]} for c in coords] # already doubles nested, fine for test
|
||||||
}
|
}
|
||||||
logger.debug(f"{trajectories}")
|
# logger.debug(f"{trajectories}")
|
||||||
self.trajectory_socket.send_string(json.dumps(trajectories))
|
frame.trajectories = trajectories
|
||||||
|
|
||||||
|
current_time = time.time()
|
||||||
|
logger.debug(f"Trajectories: {len(trajectories)}. Current frame delay = {current_time-frame.time}s (trajectories: {current_time - start_time}s)")
|
||||||
|
self.trajectory_socket.send(pickle.dumps(frame))
|
||||||
|
# self.trajectory_socket.send_string(json.dumps(trajectories))
|
||||||
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
# provide a {ID: {id: ID, history: [[x,y],[x,y],...]}}
|
||||||
# TODO: provide a track object that actually keeps history (unlike tracker)
|
# TODO: provide a track object that actually keeps history (unlike tracker)
|
||||||
|
|
||||||
#TODO calculate fps (also for other loops to see asynchonity)
|
#TODO calculate fps (also for other loops to see asynchonity)
|
||||||
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
|
# fpsfilter=fpsfilter*.9+(1/dt)*.1 #trust value in order to stabilize fps display
|
||||||
|
logger.info('Stopping')
|
||||||
|
|
||||||
|
|
||||||
def detect_persons(self, frame) -> Detections:
|
def detect_persons(self, frame) -> Detections:
|
||||||
|
@ -118,6 +131,6 @@ class Tracker:
|
||||||
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
|
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
|
||||||
|
|
||||||
|
|
||||||
def run_tracker(config: Namespace):
|
def run_tracker(config: Namespace, is_running: Event):
|
||||||
router = Tracker(config)
|
router = Tracker(config, is_running)
|
||||||
router.track()
|
router.track()
|
Loading…
Reference in a new issue