Trajectory prediction - test in browser

Ruben van de Ven 2023-10-12 20:28:17 +02:00
commit 7c06913d88
12 changed files with 4305 additions and 0 deletions

1
.python-version Normal file

@@ -0,0 +1 @@
3.10.4

3229
poetry.lock generated Normal file

File diff suppressed because it is too large

15
pyproject.toml Normal file

@@ -0,0 +1,15 @@
[tool.poetry]
name = "trap"
version = "0.1.0"
description = "Art installation with trajectory prediction"
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.10,<3.12"
trajectron-plus-plus = { path = "../Trajectron-plus-plus/", develop = true }
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

5
run_server.py Normal file

@@ -0,0 +1,5 @@
from trap import plumber
if __name__ == "__main__":
plumber.start()
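Run from the repository root, e.g. python run_server.py -v (all flags are defined in trap/config.py; --bypass-prediction skips the inference process so the browser test page can be debugged on its own).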

0
trap/__init__.py Normal file

168
trap/config.py Normal file

@@ -0,0 +1,168 @@
import argparse
from pathlib import Path
parser = argparse.ArgumentParser()
parser.add_argument(
'--verbose',
'-v',
help="Increase verbosity. Add multiple times to increase further.",
action='count', default=0
)
parser.add_argument(
'--remote-log-addr',
    help="Connect to a remote logger like cutelog. Specify the IP address",
type=str,
)
parser.add_argument(
'--remote-log-port',
help="Connect to a remote logger like cutelog. Specify the port",
type=int,
default=19996
)
# parser.add_argument('--foo')
inference_parser = parser.add_argument_group('inference server')
connection_parser = parser.add_argument_group('connection')
frame_emitter_parser = parser.add_argument_group('Frame emitter')
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
type=str, # TODO: make into Path
default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
inference_parser.add_argument("--conf",
help="path to json config file for hyperparameters, relative to model_dir",
type=str,
default='config.json')
# Model Parameters (hyperparameters)
inference_parser.add_argument("--offline_scene_graph",
help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--dynamic_edges",
help="whether to use dynamic edges or not, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--edge_state_combine_method",
help="the method to use for combining edges of the same type",
type=str,
default='sum')
inference_parser.add_argument("--edge_influence_combine_method",
help="the method to use for combining edge influences",
type=str,
default='attention')
inference_parser.add_argument('--edge_addition_filter',
nargs='+',
help="what scaling to use for edges as they're created",
type=float,
default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp
# and short edge addition, then you need to have a 0.0 at the
# beginning, e.g. [0.0, 1.0].
inference_parser.add_argument('--edge_removal_filter',
nargs='+',
help="what scaling to use for edges as they're removed",
type=float,
default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like
# the default, then you need to have a 0.0 at the end.
inference_parser.add_argument('--incl_robot_node',
help="whether to include a robot node in the graph or simply model all agents",
action='store_true')
inference_parser.add_argument('--map_encoding',
help="Whether to use map encoding or not",
action='store_true')
inference_parser.add_argument('--no_edge_encoding',
help="Whether to use neighbors edge encoding",
action='store_true')
inference_parser.add_argument('--batch_size',
help='training batch size',
type=int,
default=256)
inference_parser.add_argument('--k_eval',
help='how many samples to take during evaluation',
type=int,
default=25)
# Data Parameters
inference_parser.add_argument("--eval_data_dict",
help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
type=str,
default='../Trajectron-plus-plus/experiments/processed/eth_test.pkl')
inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=str,
default='./OUT/test_inference')
# inference_parser.add_argument('--device',
# help='what device to perform training on',
# type=str,
# default='cuda:0')
inference_parser.add_argument("--eval_device",
help="what device to use during inference",
type=str,
default="cpu")
inference_parser.add_argument('--seed',
help='manual seed to use, default is 123',
type=int,
default=123)
# Internal connections.
connection_parser.add_argument('--zmq-trajectory-addr',
                        help='Manually specify the communication address for the trajectory messages',
                        type=str,
                        default="ipc:///tmp/feeds/traj")
connection_parser.add_argument('--zmq-camera-stream-addr',
                        help='Manually specify the communication address for the camera stream messages',
                        type=str,
                        default="ipc:///tmp/feeds/img")
connection_parser.add_argument('--zmq-prediction-addr',
                        help='Manually specify the communication address for the prediction messages',
                        type=str,
                        default="ipc:///tmp/feeds/preds")
connection_parser.add_argument('--zmq-frame-addr',
                        help='Manually specify the communication address for the frame messages',
                        type=str,
                        default="ipc:///tmp/feeds/frame")
connection_parser.add_argument('--ws-port',
                        help='Port to listen on for incoming websocket connections. Also serves the testing HTML page.',
                        type=int,
                        default=8888)
connection_parser.add_argument('--bypass-prediction',
                        help='For debugging purposes: forward websocket input straight to the output, skipping the inference server',
                        action='store_true')
# Frame emitter
frame_emitter_parser.add_argument("--video-src",
help="source video to track from",
type=Path,
default='../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4')
#TODO: camera
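
For reference, a minimal sketch of how these flags resolve once parsed (argparse maps dashes to underscores; the video path here is hypothetical):

from trap.config import parser

args = parser.parse_args(['-v', '--video-src', 'demo.mp4'])
print(args.verbose)         # 1
print(args.zmq_frame_addr)  # ipc:///tmp/feeds/frame
print(args.video_src)       # demo.mp4 (a pathlib.Path)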

51
trap/frame_emitter.py Normal file

@@ -0,0 +1,51 @@
from argparse import Namespace
import time
import cv2
import zmq
class FrameEmitter:
    '''
    Emit frames in a separate thread so they can be throttled,
    or thrown away when the rest of the system cannot keep up
    '''
def __init__(self, config: Namespace) -> None:
self.config = config
context = zmq.Context()
        self.frame_sock = context.socket(zmq.PUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame; set before bind so it takes effect
        self.frame_sock.bind(config.zmq_frame_addr)
def emit_video(self):
video = cv2.VideoCapture(str(self.config.video_src))
fps = video.get(cv2.CAP_PROP_FPS)
frame_duration = 1./fps
prev_time = time.time()
while True:
ret, frame = video.read()
# seek to 0 if video has finished. Infinite loop
if not ret:
video.set(cv2.CAP_PROP_POS_FRAMES, 0)
ret, frame = video.read()
assert ret is not False # not really error proof...
self.frame_sock.send(frame)
# defer next loop
new_frame_time = time.time()
time_diff = (new_frame_time - prev_time)
            if time_diff < frame_duration:
                time.sleep(frame_duration - time_diff)
                # credit the sleep to the reference time so the pace holds across frames
                prev_time = new_frame_time + (frame_duration - time_diff)
            else:
                prev_time = new_frame_time
def run_frame_emitter(config: Namespace):
router = FrameEmitter(config)
router.emit_video()
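
A minimal sketch of the consuming side, assuming the subscriber knows the frame geometry (the emitter sends the bare ndarray buffer, so shape and dtype have to travel out of band; the 1080x1920 dimensions below are hypothetical):

import numpy as np
import zmq

context = zmq.Context()
frame_sock = context.socket(zmq.SUB)
frame_sock.setsockopt(zmq.CONFLATE, 1)     # drop all but the latest frame; set before connect
frame_sock.connect("ipc:///tmp/feeds/frame")
frame_sock.setsockopt(zmq.SUBSCRIBE, b'')  # required, or the SUB socket receives nothing

buf = frame_sock.recv()
frame = np.frombuffer(buf, dtype=np.uint8).reshape(1080, 1920, 3)  # hypothetical geometry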

42
trap/plumber.py Normal file

@@ -0,0 +1,42 @@
import logging
from logging.handlers import SocketHandler
from multiprocessing import Process, Queue
from trap.config import parser
from trap.frame_emitter import run_frame_emitter
from trap.prediction_server import InferenceServer, run_inference_server
from trap.socket_forwarder import run_ws_forwarder
logger = logging.getLogger("trap.plumbing")
def start():
args = parser.parse_args()
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
logging.basicConfig(
level=loglevel,
)
if args.remote_log_addr:
logging.captureWarnings(True)
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
root_logger.addHandler(socket_handler)
# instantiating process with arguments
procs = [
# Process(target=run_ws_forwarder, args=(args,)),
Process(target=run_frame_emitter, args=(args,)),
]
if not args.bypass_prediction:
procs.append(
Process(target=run_inference_server, args=(args,)),
)
logger.info("start")
for proc in procs:
proc.start()
for proc in procs:
proc.join()
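
trap/tracker.py already defines a run_tracker entry point that is not launched here yet; a sketch of how it could join the same pattern (an assumption, not part of this commit):

from trap.tracker import run_tracker

# inside start(), alongside the frame emitter process:
procs.append(Process(target=run_tracker, args=(args,)))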

352
trap/prediction_server.py Normal file

@@ -0,0 +1,352 @@
# adapted from Trajectron++ online_server.py
import logging
from multiprocessing import Queue
import os
import time
import json
import pandas as pd
import torch
import dill
import random
import pathlib
import numpy as np
from trajectron.environment.data_utils import derivative_of
from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
from trajectron.environment.node import Node
from trajectron.environment.node_type import NodeType
import matplotlib.pyplot as plt
import zmq
logger = logging.getLogger("trap.inference")
# if not torch.cuda.is_available() or self.config.device == 'cpu':
# self.config.device = torch.device('cpu')
# else:
# if torch.cuda.device_count() == 1:
# # If you have CUDA_VISIBLE_DEVICES set, which you should,
# # then this will prevent leftover flag arguments from
# # messing with the device allocation.
# self.config.device = 'cuda:0'
# self.config.device = torch.device(self.config.device)
def create_online_env(env, hyperparams, scene_idx, init_timestep):
test_scene = env.scenes[scene_idx]
online_scene = Scene(timesteps=init_timestep + 1,
map=test_scene.map,
dt=test_scene.dt)
online_scene.nodes = test_scene.get_nodes_clipped_at_time(
timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'],
init_timestep + 1),
state=hyperparams['state'])
online_scene.robot = test_scene.robot
online_scene.calculate_scene_graph(attention_radius=env.attention_radius,
edge_addition_filter=hyperparams['edge_addition_filter'],
edge_removal_filter=hyperparams['edge_removal_filter'])
return Environment(node_type_list=env.node_type_list,
standardization=env.standardization,
scenes=[online_scene],
attention_radius=env.attention_radius,
robot_type=env.robot_type)
def get_maps_for_input(input_dict, scene, hyperparams):
scene_maps = list()
scene_pts = list()
heading_angles = list()
patch_sizes = list()
nodes_with_maps = list()
for node in input_dict:
if node.type in hyperparams['map_encoder']:
x = input_dict[node]
me_hyp = hyperparams['map_encoder'][node.type]
if 'heading_state_index' in me_hyp:
heading_state_index = me_hyp['heading_state_index']
                # We have to rotate the map in the opposite direction of the agent to match them
if type(heading_state_index) is list: # infer from velocity or heading vector
heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
x[-1, heading_state_index[0]]) * 180 / np.pi
else:
heading_angle = -x[-1, heading_state_index] * 180 / np.pi
else:
heading_angle = None
scene_map = scene.map[node.type]
map_point = x[-1, :2]
patch_size = hyperparams['map_encoder'][node.type]['patch_size']
scene_maps.append(scene_map)
scene_pts.append(map_point)
heading_angles.append(heading_angle)
patch_sizes.append(patch_size)
nodes_with_maps.append(node)
if heading_angles[0] is None:
heading_angles = None
else:
heading_angles = torch.Tensor(heading_angles)
maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
scene_pts=torch.Tensor(scene_pts),
patch_size=patch_sizes[0],
rotation=heading_angles)
maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
return maps_dict
class InferenceServer:
def __init__(self, config: dict):
self.config = config
context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
print(self.prediction_socket)
def run(self):
if self.config.seed is not None:
random.seed(self.config.seed)
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(self.config.seed)
# Choose one of the model directory names under the experiment/*/models folders.
# Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
# model_dir = os.path.join(self.config.log_dir, 'int_ee')
# model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3'
# Load hyperparameters from json
config_file = os.path.join(self.config.model_dir, self.config.conf)
if not os.path.exists(config_file):
raise ValueError('Config json not found!')
with open(config_file, 'r') as conf_json:
hyperparams = json.load(conf_json)
# Add hyperparams from arguments
hyperparams['dynamic_edges'] = self.config.dynamic_edges
hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method
hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method
hyperparams['edge_addition_filter'] = self.config.edge_addition_filter
hyperparams['edge_removal_filter'] = self.config.edge_removal_filter
hyperparams['batch_size'] = self.config.batch_size
hyperparams['k_eval'] = self.config.k_eval
hyperparams['offline_scene_graph'] = self.config.offline_scene_graph
hyperparams['incl_robot_node'] = self.config.incl_robot_node
hyperparams['edge_encoding'] = not self.config.no_edge_encoding
hyperparams['use_map_encoding'] = self.config.map_encoding
# hyperparams['maximum_history_length'] = 12 # test
logger.info(f"Use hyperparams: {hyperparams=}")
output_save_dir = os.path.join(self.config.output_dir, 'pred_figs')
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
with open(self.config.eval_data_dict, 'rb') as f:
eval_env = dill.load(f, encoding='latin1')
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
for scene in eval_env.scenes:
scene.add_robot_from_nodes(eval_env.robot_type)
logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
# Creating a dummy environment with a single scene that contains information about the world.
# When using this code, feel free to use whichever scene index or initial timestep you wish.
scene_idx = 0
# You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
# so that you can immediately start incremental inference from the 3rd timestep onwards.
init_timestep = 1
eval_scene = eval_env.scenes[scene_idx]
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device)
model_registrar.load_models(iter_num=100)
trajectron = OnlineTrajectron(model_registrar,
hyperparams,
self.config.eval_device)
# If you want to see what different robot futures do to the predictions, uncomment this line as well as
# related "... += adjustment" lines below.
# adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)
# Here's how you'd incrementally run the model, e.g. with streaming data.
trajectron.set_environment(online_env, init_timestep)
timestep = init_timestep + 1
while True:
timestep += 1
# for timestep in range(init_timestep + 1, eval_scene.timesteps):
# input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
            # TODO: see process_data.py on how to create a node, then provide nodes + incoming data columns
# data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
# x = node_values[:, 0]
# y = node_values[:, 1]
# vx = derivative_of(x, scene.dt)
# vy = derivative_of(y, scene.dt)
# ax = derivative_of(vx, scene.dt)
# ay = derivative_of(vy, scene.dt)
# data_dict = {('position', 'x'): x,
# ('position', 'y'): y,
# ('velocity', 'x'): vx,
# ('velocity', 'y'): vy,
# ('acceleration', 'x'): ax,
# ('acceleration', 'y'): ay}
# node_data = pd.DataFrame(data_dict, columns=data_columns)
# node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data)
data = self.trajectory_socket.recv_string()
trajectory_data = json.loads(data)
logger.info(f"Receive {trajectory_data}")
# class FakeNode:
# def __init__(self, node_type: NodeType):
# self.type = node_type
input_dict = {}
for identifier, trajectory in trajectory_data.items():
# if len(trajectory['history']) < 7:
# # TODO: these trajectories should still be in the output, but without predictions
# continue
                # TODO: modify this into a mapping function between JS data and the expected Node format
# node = FakeNode(online_env.NodeType.PEDESTRIAN)
history = [[h['x'], h['y']] for h in trajectory['history']]
history = np.array(history)
x = history[:, 0]
y = history[:, 1]
# TODO: calculate dt based on input
vx = derivative_of(x, 0.2) #eval_scene.dt
vy = derivative_of(y, 0.2)
ax = derivative_of(vx, 0.2)
ay = derivative_of(vy, 0.2)
data_dict = {('position', 'x'): x[:],
('position', 'y'): y[:],
('velocity', 'x'): vx[:],
('velocity', 'y'): vy[:],
('acceleration', 'x'): ax[:],
('acceleration', 'y'): ay[:]}
data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']])
node_data = pd.DataFrame(data_dict, columns=data_columns)
node = Node(
node_type=online_env.NodeType.PEDESTRIAN,
node_id=identifier,
data=node_data,
first_timestep=timestep
)
input_dict[node] = np.array([x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]])
# print(input_dict)
if not len(input_dict):
                # skip if our input is empty
                # TODO: we want to send out an empty result...
data = json.dumps({})
self.prediction_socket.send_string(data)
continue
maps = None
if hyperparams['use_map_encoding']:
maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
# print(maps)
robot_present_and_future = None
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
timestep + hyperparams['prediction_horizon']]),
hyperparams['state'][eval_scene.robot.type],
padding=0.0)
robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0)
# robot_present_and_future += adjustment
start = time.time()
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=16, # TODO: make variable
num_samples=3, # TODO: make variable
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
1. / (end - start), len(trajectron.nodes),
trajectron.scene_graph.get_num_edges()))
# unsure what this bit from online_prediction.py does:
# detailed_preds_dict = dict()
# for node in eval_scene.nodes:
# if node in preds:
# detailed_preds_dict[node] = preds[node]
            # adapted from trajectron.visualization
            # prediction_dict provides the actual predictions
            # histories_dict provides the trajectory used for prediction
            # futures_dict is the Ground Truth, which is unavailable in an online setting
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
eval_scene.dt,
hyperparams['maximum_history_length'],
hyperparams['prediction_horizon']
)
assert(len(prediction_dict.keys()) <= 1)
            if len(prediction_dict.keys()) == 0:
                continue  # no predictions this timestep; keep the server loop running
ts_key = list(prediction_dict.keys())[0]
prediction_dict = prediction_dict[ts_key]
histories_dict = histories_dict[ts_key]
futures_dict = futures_dict[ts_key]
response = {}
print(histories_dict)
for node in histories_dict:
history = histories_dict[node]
# future = futures_dict[node]
predictions = prediction_dict[node]
if not len(history) or np.isnan(history[-1]).any():
continue
response[node.id] = {
'id': node.id,
'history': history.tolist(),
'predictions': predictions[0].tolist() # use batch 0
}
data = json.dumps(response)
self.prediction_socket.send_string(data)
def run_inference_server(config):
s = InferenceServer(config)
s.run()
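
A sketch of the wire format the server consumes and emits, inferred from the parsing above (identifier and coordinates hypothetical):

import json
import zmq

context = zmq.Context()
traj_sock = context.socket(zmq.PUB)
traj_sock.bind("ipc:///tmp/feeds/traj")  # the server's SUB socket connects here

# input: one entry per tracked subject, each with an {x, y} history
tracker_state = {"0": {"history": [{"x": 1.0, "y": 2.0}, {"x": 1.1, "y": 2.2}]}}
traj_sock.send_string(json.dumps(tracker_state))

# output arrives on ipc:///tmp/feeds/preds, per node id:
# {"0": {"id": "0", "history": [[x, y], ...], "predictions": [[x, y], ...]}}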

162
trap/socket_forwarder.py Normal file

@@ -0,0 +1,162 @@
from argparse import Namespace
import asyncio
import logging
from typing import Set, Union, Dict, Any
from typing_extensions import Self
from urllib.error import HTTPError
import tornado.ioloop
import tornado.web
import tornado.websocket
import zmq
import zmq.asyncio
logger = logging.getLogger("trap.forwarder")
class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
def initialize(self, zmq_socket: zmq.asyncio.Socket):
self.zmq_socket = zmq_socket
async def on_message(self, message):
logger.debug(f"recieve msg")
try:
await self.zmq_socket.send_string(message)
# msg = json.loads(message)
except Exception as e:
# self.send({'alert': 'Invalid request: {}'.format(e)})
logger.exception(e)
# self.write_message(u"You said: " + message)
def open(self, p=None):
logger.info(f"connected {self.request.remote_ip}")
# client disconnected
def on_close(self):
logger.info(f"Client disconnected: {self.request.remote_ip}")
class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
connections: Set[Self] = set()
def initialize(self, config):
self.config = config
def on_message(self, message):
logger.warning(f"Receiving message on send-only ws handler: {message}")
def open(self, p=None):
logger.info(f"Prediction WS connected {self.request.remote_ip}")
self.__class__.connections.add(self)
# client disconnected
def on_close(self):
self.__class__.rmConnection(self)
logger.info(f"Client disconnected: {self.request.remote_ip}")
@classmethod
def rmConnection(cls, client):
if client not in cls.connections:
return
cls.connections.remove(client)
@classmethod
def hasConnection(cls, client):
return client in cls.connections
@classmethod
def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
if msg is None:
logger.critical("Tried to send 'none'")
return
toRemove = []
for client in cls.connections:
try:
client.write_message(msg)
except tornado.websocket.WebSocketClosedError as e:
logger.warning(f"Not properly closed websocket connection")
toRemove.append(client) # If we remove it here from the set we get an exception about changing set size during iteration
for client in toRemove:
cls.rmConnection(client)
class DemoHandler(tornado.web.RequestHandler):
def initialize(self, config: Namespace):
self.config = config
def get(self):
self.render("index.html", ws_port=self.config.ws_port)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
def set_extra_headers(self, path):
"""For subclass to add extra headers to the response"""
if path[-5:] == ".html":
self.set_header("Access-Control-Allow-Origin", "*")
if path[-4:] == ".svg":
self.set_header("Content-Type", "image/svg+xml")
class WsRouter:
def __init__(self, config: Namespace):
self.config = config
context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.connect(config.zmq_prediction_addr)
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.application = tornado.web.Application(
[
(
r"/ws/trajectory",
WebSocketTrajectoryHandler,
{
"zmq_socket": self.trajectory_socket
},
),
(
r"/ws/prediction",
WebSocketPredictionHandler,
{
"config": config,
},
),
(r"/", DemoHandler, {"config": config}),
# (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}),
],
template_path = 'trap/web/',
compiled_template_cache=False)
def start(self):
evt_loop = asyncio.new_event_loop()
asyncio.set_event_loop(evt_loop)
# loop = tornado.ioloop.IOLoop.current()
logger.info(f"Listen on {self.config.ws_port}")
self.application.listen(self.config.ws_port)
loop = asyncio.get_event_loop()
task = evt_loop.create_task(self.prediction_forwarder())
evt_loop.run_forever()
async def prediction_forwarder(self):
logger.info("Starting prediction forwarder")
while True:
msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
def run_ws_forwarder(config: Namespace):
router = WsRouter(config)
router.start()
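
A hypothetical end-to-end client for the two endpoints, assuming the third-party websockets package (any websocket client would do):

import asyncio
import json
import websockets  # assumed dependency, not part of this commit

async def main():
    base = "ws://localhost:8888"
    async with websockets.connect(f"{base}/ws/trajectory") as traj, \
               websockets.connect(f"{base}/ws/prediction") as pred:
        await traj.send(json.dumps({"0": {"history": [{"x": 0.0, "y": 0.0}]}}))
        print(json.loads(await pred.recv()))

asyncio.run(main())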

83
trap/tracker.py Normal file

@@ -0,0 +1,83 @@
from argparse import Namespace
import numpy as np
import torch
import zmq
import cv2
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights
from deep_sort_realtime.deepsort_tracker import DeepSort
from typing import List, Tuple

Detection = Tuple[int, int, int, int, float, int]  # x1, y1, x2, y2, score, label
Detections = List[Detection]
class Tracker:
def __init__(self, config: Namespace):
context = zmq.Context()
        self.frame_sock = context.socket(zmq.SUB)
        self.frame_sock.setsockopt(zmq.CONFLATE, 1)  # only keep latest frame; set before connect
        self.frame_sock.connect(config.zmq_frame_addr)  # the frame emitter binds this address
        self.frame_sock.setsockopt(zmq.SUBSCRIBE, b'')  # a SUB socket receives nothing without a subscription
# TODO: config device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35)
self.model.to(self.device)
# Put the model in inference mode
self.model.eval()
# Get the transforms for the model's weights
self.preprocess = weights.transforms().to(self.device)
self.mot_tracker = DeepSort(max_age=5)
def track(self):
while True:
frame = self.frame_sock.recv()
detections = self.detect_persons(frame)
tracks = self.mot_tracker.update_tracks(detections, frame=frame)
# TODO: provide a track object that actually keeps history (unlike tracker)
def detect_persons(self, frame) -> Detections:
t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # change the axes of the loaded image to be compatible with torch.io.read_image (which has C,W,H format instead of W,H,C)
t = t.permute(2, 0, 1)
batch = self.preprocess(t)[None, :].to(self.device)
# no_grad can be used on inference, should be slightly faster
with torch.no_grad():
predictions = self.model(batch)
prediction = predictions[0] # we feed only one frame at once
# TODO: check if we need e.g. cyclist
mask = prediction['labels'] == 1 # if we want more than one label: np.isin(prediction['labels'], [1,86])
scores = prediction['scores'][mask]
labels = prediction['labels'][mask]
boxes = prediction['boxes'][mask]
        # TODO: introduce confidence and NMS suppression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb
# (which I _think_ we better do after filtering)
# alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c
# dets - a numpy array of detections in the format [[x1,y1,x2,y2,score, label],[x1,y1,x2,y2,score, label],...]
detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())])
detections = self.detect_persons_deepsort_wrapper(detections)
return detections
    @staticmethod
    def detect_persons_deepsort_wrapper(detections):
"""make detect_persons() compatible with
deep_sort_realtime tracker by going from ltrb to ltwh and
different nesting
"""
return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections]
def run_tracker(config: Namespace):
router = Tracker(config)
router.track()
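
A quick sanity check of the ltrb-to-ltwh conversion done by the wrapper (values hypothetical):

from trap.tracker import Tracker

det = [10.0, 20.0, 50.0, 80.0, 0.9, 1]  # x1, y1, x2, y2, score, label
ltwh, score, label = Tracker.detect_persons_deepsort_wrapper([det])[0]
assert ltwh == [10.0, 20.0, 40.0, 60.0]  # width = x2 - x1, height = y2 - y1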

197
trap/web/index.html Normal file

@@ -0,0 +1,197 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Trajectory Prediction Browser Test</title>
<style>
body {
background: black;
}
#field {
background: white;
width: 100%;
height: 100%;
}
</style>
</head>
<body>
<canvas id="field" width="1500" height="1500">
</canvas>
<script>
// minified https://github.com/joewalnes/reconnecting-websocket
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
</script>
<script>
// map the field to coordinates of our dummy tracker
const field_range = { x: [-10, 10], y: [-10, 10] }
// Create WebSocket connection.
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
const prediction_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/prediction`);
let is_moving = false;
const fieldEl = document.getElementById('field');
let current_data = {}
// Listen for messages
prediction_socket.addEventListener("message", (event) => {
// console.log("Message from server ", event.data);
current_data = JSON.parse(event.data);
});
prediction_socket.addEventListener("open", (e) => appendAndSendPositions());
function getMousePos(canvas, evt) {
const rect = canvas.getBoundingClientRect();
return {
x: evt.clientX - rect.left,
y: evt.clientY - rect.top
};
}
function mouse_coordinates_to_position(coordinates) {
const x_range = field_range.x[1] - field_range.x[0]
const x = (coordinates.x / fieldEl.clientWidth) * x_range + field_range.x[0]
const y_range = field_range.y[1] - field_range.y[0]
const y = (coordinates.y / fieldEl.clientWidth) * y_range + field_range.y[0]
return { x: x, y: y }
}
function position_to_canvas_coordinate(position) {
const x_range = field_range.x[1] - field_range.x[0]
const y_range = field_range.y[1] - field_range.y[0]
const x = Array.isArray(position) ? position[0] : position.x;
const y = Array.isArray(position) ? position[1] : position.y;
return {
x: (x - field_range.x[0]) * fieldEl.width / x_range,
y: (y - field_range.y[0]) * fieldEl.width / y_range,
}
}
// helper function so we can spread
function coord_as_list(coord) {
return [coord.x, coord.y]
}
let tracker = {}
let person_counter = 0
class Person {
constructor(id) {
this.id = id;
this.history = [];
this.prediction = []
}
addToHistory(position) {
this.history.push(position);
}
}
let current_pos = null;
function appendAndSendPositions(){
if(is_moving && current_pos!==null){
// throttled update of tracker on movement
tracker[person_counter].addToHistory(current_pos);
}
for(const person_id in tracker){
if(person_id != person_counter){ // compare int/str
// fade out old tracks
tracker[person_id].history.shift()
if(!tracker[person_id].history.length){
delete tracker[person_id]
}
}
}
console.log(tracker)
trajectory_socket.send(JSON.stringify(tracker))
setTimeout(appendAndSendPositions, 200)
}
fieldEl.addEventListener('mousedown', (event) => {
tracker[person_counter] = new Person(person_counter);
is_moving = true;
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
current_pos = position;
// tracker[person_counter].addToHistory(current_pos);
// trajectory_socket.send(JSON.stringify(tracker))
});
fieldEl.addEventListener('mousemove', (event) => {
if (!is_moving) return;
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
current_pos = position;
// tracker[person_counter].addToHistory(current_pos);
// trajectory_socket.send(JSON.stringify(tracker))
});
document.addEventListener('mouseup', (e) => {
person_counter++;
is_moving = false;
})
const ctx = fieldEl.getContext("2d");
function drawFrame() {
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
ctx.save();
for (let id in current_data) {
const person = current_data[id];
if (person.history.length > 1) {
const hist = structuredClone(person.history)
// draw current position:
ctx.beginPath()
ctx.arc(
...coord_as_list(position_to_canvas_coordinate(hist[hist.length - 1])),
5, //radius
0, 2 * Math.PI);
ctx.fill()
ctx.beginPath()
ctx.lineWidth = 3;
ctx.strokeStyle = "#325FA2";
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(hist.shift())));
for (const position of hist) {
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
}
ctx.stroke();
}
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
// multiple predictions can be sampled
person.predictions.forEach((prediction, i) => {
ctx.beginPath()
ctx.lineWidth = i === 1 ? 3 : 0.2;
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
for (const position of prediction) {
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
}
ctx.stroke();
});
}
}
ctx.restore();
window.requestAnimationFrame(drawFrame);
}
window.requestAnimationFrame(drawFrame);
</script>
</body>
</html>