Compare commits

..

1 commit

Author SHA1 Message Date
Ruben van de Ven
51d6157af9 Refactor Trajectron++ to work as module and expose training command 2023-12-06 12:28:56 +01:00
12 changed files with 574 additions and 1193 deletions

931
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,17 @@
[tool.poetry] [tool.poetry]
name = "Trajectron-plus-plus" name = "Trajectron-plus-plus"
version = "0.1.0" version = "0.1.1"
description = "Predict trajectories for anomaly detection" description = "This repository contains the code for Trajectron++: Dynamically-Feasible Trajectory Forecasting With Heterogeneous Data by Tim Salzmann*, Boris Ivanovic*, Punarjay Chakravarty, and Marco Pavone (* denotes equal contribution)."
authors = ["Ruben van de Ven <git@rubenvandeven.com>"] authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
readme = "README.md" readme = "README.md"
license = "MIT"
packages = [
{ include = "trajectron" },
]
[tool.poetry.scripts]
trajectron_train = "trajectron.train:main"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.9,<3.12" python = "^3.9,<3.12"
@ -28,8 +36,6 @@ matplotlib = "^3.5"
#scikit-learn = "0.22.1" #scikit-learn = "0.22.1"
#filterpy = "^1.4.5" #filterpy = "^1.4.5"
tqdm = "^4.65.0" tqdm = "^4.65.0"
#ipywidgets = "^8.0.6"
#deep-sort-realtime = "^1.3.2"
scipy = "^1.11.3" scipy = "^1.11.3"
pandas = "^2.1.1" pandas = "^2.1.1"
orjson = "^3.9.7" orjson = "^3.9.7"
@ -41,7 +47,7 @@ notebook = "^7.0.4"
scikit-learn = "^1.3.1" scikit-learn = "^1.3.1"
seaborn = "^0.13.0" seaborn = "^0.13.0"
setuptools = "^68.2.2" setuptools = "^68.2.2"
tensorboard = "1.14" tensorboard = "^2.11"
tensorboardx = "^2.6.2.2" tensorboardx = "^2.6.2.2"

View file

@ -1,5 +0,0 @@
from trajpred import plumber
if __name__ == "__main__":
plumber.start()

View file

@ -4,7 +4,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--conf", parser.add_argument("--conf",
help="path to json config file for hyperparameters", help="path to json config file for hyperparameters",
type=str, type=str,
default='../config/config.json') default='./config/config.json')
parser.add_argument("--debug", parser.add_argument("--debug",
help="disable all disk writing processes.", help="disable all disk writing processes.",
@ -97,7 +97,7 @@ parser.add_argument('--no_edge_encoding',
parser.add_argument("--data_dir", parser.add_argument("--data_dir",
help="what dir to look in for data", help="what dir to look in for data",
type=str, type=str,
default='../experiments/processed') default='./experiments/processed')
parser.add_argument("--train_data_dict", parser.add_argument("--train_data_dict",
help="what file to load for training data", help="what file to load for training data",

View file

@ -3,7 +3,7 @@ from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import binary_dilation from scipy.ndimage import binary_dilation
from scipy.stats import gaussian_kde from scipy.stats import gaussian_kde
from trajectron.utils import prediction_output_to_trajectories from trajectron.utils import prediction_output_to_trajectories
import trajectron.visualization import trajectron.visualization as visualization
from matplotlib import pyplot as plt from matplotlib import pyplot as plt

View file

@ -9,14 +9,14 @@ import random
import pathlib import pathlib
import warnings import warnings
from tqdm import tqdm from tqdm import tqdm
import visualization import trajectron.visualization as visualization
import evaluation import trajectron.evaluation as evaluation
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from argument_parser import args from trajectron.argument_parser import args
from model.trajectron import Trajectron from trajectron.model.trajectron import Trajectron
from model.model_registrar import ModelRegistrar from trajectron.model.model_registrar import ModelRegistrar
from model.model_utils import cyclical_lr from trajectron.model.model_utils import cyclical_lr
from model.dataset import EnvironmentDataset, collate from trajectron.model.dataset import EnvironmentDataset, collate
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
# torch.autograd.set_detect_anomaly(True) # torch.autograd.set_detect_anomaly(True)

View file

View file

@ -1,153 +0,0 @@
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--verbose',
'-v',
help="Increase verbosity. Add multiple times to increase further.",
action='count', default=0
)
parser.add_argument(
'--remote-log-addr',
help="Connect to a remote logger like cutelog. Specify the ip",
type=str,
)
parser.add_argument(
'--remote-log-port',
help="Connect to a remote logger like cutelog. Specify the port",
type=int,
default=19996
)
# parser.add_argument('--foo')
inference_parser = parser.add_argument_group('inference server')
connection_parser = parser.add_argument_group('connection')
inference_parser.add_argument("--model_dir",
help="directory with the model to use for inference",
type=str, # TODO: make into Path
default='./experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
inference_parser.add_argument("--conf",
help="path to json config file for hyperparameters, relative to model_dir",
type=str,
default='config.json')
# Model Parameters (hyperparameters)
inference_parser.add_argument("--offline_scene_graph",
help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--dynamic_edges",
help="whether to use dynamic edges or not, options are 'no' and 'yes'",
type=str,
default='yes')
inference_parser.add_argument("--edge_state_combine_method",
help="the method to use for combining edges of the same type",
type=str,
default='sum')
inference_parser.add_argument("--edge_influence_combine_method",
help="the method to use for combining edge influences",
type=str,
default='attention')
inference_parser.add_argument('--edge_addition_filter',
nargs='+',
help="what scaling to use for edges as they're created",
type=float,
default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp
# and short edge addition, then you need to have a 0.0 at the
# beginning, e.g. [0.0, 1.0].
inference_parser.add_argument('--edge_removal_filter',
nargs='+',
help="what scaling to use for edges as they're removed",
type=float,
default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like
# the default, then you need to have a 0.0 at the end.
inference_parser.add_argument('--incl_robot_node',
help="whether to include a robot node in the graph or simply model all agents",
action='store_true')
inference_parser.add_argument('--map_encoding',
help="Whether to use map encoding or not",
action='store_true')
inference_parser.add_argument('--no_edge_encoding',
help="Whether to use neighbors edge encoding",
action='store_true')
inference_parser.add_argument('--batch_size',
help='training batch size',
type=int,
default=256)
inference_parser.add_argument('--k_eval',
help='how many samples to take during evaluation',
type=int,
default=25)
# Data Parameters
inference_parser.add_argument("--eval_data_dict",
help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
type=str,
default='./experiments/processed/eth_test.pkl')
inference_parser.add_argument("--output_dir",
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
type=str,
default='../experiments/pedestrians/OUT/test_inference')
# inference_parser.add_argument('--device',
# help='what device to perform training on',
# type=str,
# default='cuda:0')
inference_parser.add_argument("--eval_device",
help="what device to use during inference",
type=str,
default="cpu")
inference_parser.add_argument('--seed',
help='manual seed to use, default is 123',
type=int,
default=123)
# Internal connections.
connection_parser.add_argument('--zmq-trajectory-addr',
help='Manually specity communication addr for the trajectory messages',
type=str,
default="ipc:///tmp/feeds/traj")
connection_parser.add_argument('--zmq-camera-stream-addr',
help='Manually specity communication addr for the camera stream messages',
type=str,
default="ipc:///tmp/feeds/img")
connection_parser.add_argument('--zmq-prediction-addr',
help='Manually specity communication addr for the prediction messages',
type=str,
default="ipc:///tmp/feeds/preds")
connection_parser.add_argument('--ws-port',
help='Port to listen for incomming websocket connections. Also serves the testing html-page.',
type=int,
default=8888)
connection_parser.add_argument('--bypass-prediction',
help='For debugging purpose: websocket input immediately to output',
action='store_true')

View file

@ -1,40 +0,0 @@
import logging
from logging.handlers import SocketHandler
from multiprocessing import Process, Queue
from trajpred.config import parser
from trajpred.prediction_server import InferenceServer, run_inference_server
from trajpred.socket_forwarder import run_ws_forwarder
logger = logging.getLogger("trajpred.plumbing")
def start():
args = parser.parse_args()
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
logging.basicConfig(
level=loglevel,
)
if args.remote_log_addr:
logging.captureWarnings(True)
root_logger = logging.getLogger()
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
root_logger.addHandler(socket_handler)
# instantiating process with arguments
procs = [
Process(target=run_ws_forwarder, args=(args,))
]
if not args.bypass_prediction:
procs.append(
Process(target=run_inference_server, args=(args,)),
)
logger.info("start")
for proc in procs:
proc.start()
for proc in procs:
proc.join()

View file

@ -1,270 +0,0 @@
# adapted from Trajectron++ online_server.py
import logging
from multiprocessing import Queue
import os
import time
import json
import torch
import dill
import random
import pathlib
import numpy as np
from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt
import zmq
logger = logging.getLogger("trajpred.inference")
# if not torch.cuda.is_available() or self.config.device == 'cpu':
# self.config.device = torch.device('cpu')
# else:
# if torch.cuda.device_count() == 1:
# # If you have CUDA_VISIBLE_DEVICES set, which you should,
# # then this will prevent leftover flag arguments from
# # messing with the device allocation.
# self.config.device = 'cuda:0'
# self.config.device = torch.device(self.config.device)
def create_online_env(env, hyperparams, scene_idx, init_timestep):
test_scene = env.scenes[scene_idx]
online_scene = Scene(timesteps=init_timestep + 1,
map=test_scene.map,
dt=test_scene.dt)
online_scene.nodes = test_scene.get_nodes_clipped_at_time(
timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'],
init_timestep + 1),
state=hyperparams['state'])
online_scene.robot = test_scene.robot
online_scene.calculate_scene_graph(attention_radius=env.attention_radius,
edge_addition_filter=hyperparams['edge_addition_filter'],
edge_removal_filter=hyperparams['edge_removal_filter'])
return Environment(node_type_list=env.node_type_list,
standardization=env.standardization,
scenes=[online_scene],
attention_radius=env.attention_radius,
robot_type=env.robot_type)
def get_maps_for_input(input_dict, scene, hyperparams):
scene_maps = list()
scene_pts = list()
heading_angles = list()
patch_sizes = list()
nodes_with_maps = list()
for node in input_dict:
if node.type in hyperparams['map_encoder']:
x = input_dict[node]
me_hyp = hyperparams['map_encoder'][node.type]
if 'heading_state_index' in me_hyp:
heading_state_index = me_hyp['heading_state_index']
# We have to rotate the map in the opposit direction of the agent to match them
if type(heading_state_index) is list: # infer from velocity or heading vector
heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
x[-1, heading_state_index[0]]) * 180 / np.pi
else:
heading_angle = -x[-1, heading_state_index] * 180 / np.pi
else:
heading_angle = None
scene_map = scene.map[node.type]
map_point = x[-1, :2]
patch_size = hyperparams['map_encoder'][node.type]['patch_size']
scene_maps.append(scene_map)
scene_pts.append(map_point)
heading_angles.append(heading_angle)
patch_sizes.append(patch_size)
nodes_with_maps.append(node)
if heading_angles[0] is None:
heading_angles = None
else:
heading_angles = torch.Tensor(heading_angles)
maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
scene_pts=torch.Tensor(scene_pts),
patch_size=patch_sizes[0],
rotation=heading_angles)
maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
return maps_dict
class InferenceServer:
def __init__(self, config: dict):
self.config = config
context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
print(self.prediction_socket)
def run(self):
if self.config.seed is not None:
random.seed(self.config.seed)
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(self.config.seed)
# Choose one of the model directory names under the experiment/*/models folders.
# Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
# model_dir = os.path.join(self.config.log_dir, 'int_ee')
# model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3'
# Load hyperparameters from json
config_file = os.path.join(self.config.model_dir, self.config.conf)
if not os.path.exists(config_file):
raise ValueError('Config json not found!')
with open(config_file, 'r') as conf_json:
hyperparams = json.load(conf_json)
# Add hyperparams from arguments
hyperparams['dynamic_edges'] = self.config.dynamic_edges
hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method
hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method
hyperparams['edge_addition_filter'] = self.config.edge_addition_filter
hyperparams['edge_removal_filter'] = self.config.edge_removal_filter
hyperparams['batch_size'] = self.config.batch_size
hyperparams['k_eval'] = self.config.k_eval
hyperparams['offline_scene_graph'] = self.config.offline_scene_graph
hyperparams['incl_robot_node'] = self.config.incl_robot_node
hyperparams['edge_encoding'] = not self.config.no_edge_encoding
hyperparams['use_map_encoding'] = self.config.map_encoding
output_save_dir = os.path.join(self.config.output_dir, 'pred_figs')
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
with open(self.config.eval_data_dict, 'rb') as f:
eval_env = dill.load(f, encoding='latin1')
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
for scene in eval_env.scenes:
scene.add_robot_from_nodes(eval_env.robot_type)
logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
# Creating a dummy environment with a single scene that contains information about the world.
# When using this code, feel free to use whichever scene index or initial timestep you wish.
scene_idx = 0
# You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
# so that you can immediately start incremental inference from the 3rd timestep onwards.
init_timestep = 1
eval_scene = eval_env.scenes[scene_idx]
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device)
model_registrar.load_models(iter_num=100)
trajectron = OnlineTrajectron(model_registrar,
hyperparams,
self.config.eval_device)
# If you want to see what different robot futures do to the predictions, uncomment this line as well as
# related "... += adjustment" lines below.
# adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)
# Here's how you'd incrementally run the model, e.g. with streaming data.
trajectron.set_environment(online_env, init_timestep)
for timestep in range(init_timestep + 1, eval_scene.timesteps):
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
maps = None
if hyperparams['use_map_encoding']:
maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
robot_present_and_future = None
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
timestep + hyperparams['prediction_horizon']]),
hyperparams['state'][eval_scene.robot.type],
padding=0.0)
robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0)
# robot_present_and_future += adjustment
start = time.time()
dists, preds = trajectron.incremental_forward(input_dict,
maps,
prediction_horizon=6,
num_samples=51,
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
1. / (end - start), len(trajectron.nodes),
trajectron.scene_graph.get_num_edges()))
# unsure what this bit from online_prediction.py does:
# detailed_preds_dict = dict()
# for node in eval_scene.nodes:
# if node in preds:
# detailed_preds_dict[node] = preds[node]
#adapted from trajectron.visualization
# prediction_dict provides the actual predictions
# histories_dict provides the trajectory used for prediction
# futures_dict is the Ground Truth, which is unvailable in an online setting
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
eval_scene.dt,
hyperparams['maximum_history_length'],
hyperparams['prediction_horizon']
)
assert(len(prediction_dict.keys()) <= 1)
if len(prediction_dict.keys()) == 0:
return
ts_key = list(prediction_dict.keys())[0]
prediction_dict = prediction_dict[ts_key]
histories_dict = histories_dict[ts_key]
futures_dict = futures_dict[ts_key]
response = {}
for node in histories_dict:
history = histories_dict[node]
# future = futures_dict[node]
predictions = prediction_dict[node]
if np.isnan(history[-1]).any():
continue
response[node.id] = {
'id': node.id,
'history': history.tolist(),
'predictions': predictions[0].tolist() # use batch 0
}
data = json.dumps(response)
self.prediction_socket.send_string(data)
# time.sleep(1)
# print(prediction_dict)
# print(histories_dict)
# print(futures_dict)
def run_inference_server(config):
s = InferenceServer(config)
s.run()

View file

@ -1,162 +0,0 @@
from argparse import Namespace
import asyncio
import logging
from typing import Set, Union, Dict, Any
from typing_extensions import Self
from urllib.error import HTTPError
import tornado.ioloop
import tornado.web
import tornado.websocket
import zmq
import zmq.asyncio
logger = logging.getLogger("trajpred.forwarder")
class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
def initialize(self, zmq_socket: zmq.asyncio.Socket):
self.zmq_socket = zmq_socket
async def on_message(self, message):
logger.debug(f"recieve msg")
try:
await self.zmq_socket.send_string(message)
# msg = json.loads(message)
except Exception as e:
# self.send({'alert': 'Invalid request: {}'.format(e)})
logger.exception(e)
# self.write_message(u"You said: " + message)
def open(self, p=None):
logger.info(f"connected {self.request.remote_ip}")
# client disconnected
def on_close(self):
logger.info(f"Client disconnected: {self.request.remote_ip}")
class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
connections: Set[Self] = set()
def initialize(self, config):
self.config = config
def on_message(self, message):
logger.warning(f"Receiving message on send-only ws handler: {message}")
def open(self, p=None):
logger.info(f"Prediction WS connected {self.request.remote_ip}")
self.__class__.connections.add(self)
# client disconnected
def on_close(self):
self.__class__.rmConnection(self)
logger.info(f"Client disconnected: {self.request.remote_ip}")
@classmethod
def rmConnection(cls, client):
if client not in cls.connections:
return
cls.connections.remove(client)
@classmethod
def hasConnection(cls, client):
return client in cls.connections
@classmethod
def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
if msg is None:
logger.critical("Tried to send 'none'")
return
toRemove = []
for client in cls.connections:
try:
client.write_message(msg)
except tornado.websocket.WebSocketClosedError as e:
logger.warning(f"Not properly closed websocket connection")
toRemove.append(client) # If we remove it here from the set we get an exception about changing set size during iteration
for client in toRemove:
cls.rmConnection(client)
class DemoHandler(tornado.web.RequestHandler):
def initialize(self, config: Namespace):
self.config = config
def get(self):
self.render("index.html", ws_port=self.config.ws_port)
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
def set_extra_headers(self, path):
"""For subclass to add extra headers to the response"""
if path[-5:] == ".html":
self.set_header("Access-Control-Allow-Origin", "*")
if path[-4:] == ".svg":
self.set_header("Content-Type", "image/svg+xml")
class WsRouter:
def __init__(self, config: Namespace):
self.config = config
context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.connect(config.zmq_prediction_addr)
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.application = tornado.web.Application(
[
(
r"/ws/trajectory",
WebSocketTrajectoryHandler,
{
"zmq_socket": self.trajectory_socket
},
),
(
r"/ws/prediction",
WebSocketPredictionHandler,
{
"config": config,
},
),
(r"/", DemoHandler, {"config": config}),
# (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}),
],
template_path = 'trajpred/web/',
compiled_template_cache=False)
def start(self):
evt_loop = asyncio.new_event_loop()
asyncio.set_event_loop(evt_loop)
# loop = tornado.ioloop.IOLoop.current()
logger.info(f"Listen on {self.config.ws_port}")
self.application.listen(self.config.ws_port)
loop = asyncio.get_event_loop()
task = evt_loop.create_task(self.prediction_forwarder())
evt_loop.run_forever()
async def prediction_forwarder(self):
logger.info("Starting prediction forwarder")
while True:
msg = await self.prediction_socket.recv_string()
logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
def run_ws_forwarder(config: Namespace):
router = WsRouter(config)
router.start()

View file

@ -1,170 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Trajectory Prediction Browser Test</title>
<style>
body {
background: black;
}
#field {
background: white;
width: 100%;
height: 100%;
}
</style>
</head>
<body>
<canvas id="field" width="1500" height="1500">
</canvas>
<script>
// minified https://github.com/joewalnes/reconnecting-websocket
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
</script>
<script>
// map the field to coordinates of our dummy tracker
const field_range = { x: [-10, 10], y: [-10, 10] }
// Create WebSocket connection.
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
const prediction_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/prediction`);
let is_moving = false;
const fieldEl = document.getElementById('field');
let current_data = {}
// Listen for messages
prediction_socket.addEventListener("message", (event) => {
// console.log("Message from server ", event.data);
current_data = JSON.parse(event.data);
});
function getMousePos(canvas, evt) {
const rect = canvas.getBoundingClientRect();
return {
x: evt.clientX - rect.left,
y: evt.clientY - rect.top
};
}
function mouse_coordinates_to_position(coordinates) {
const x_range = field_range.x[1] - field_range.x[0]
const x = (coordinates.x / fieldEl.clientWidth) * x_range + field_range.x[0]
const y_range = field_range.y[1] - field_range.y[0]
const y = (coordinates.y / fieldEl.clientWidth) * y_range + field_range.y[0]
return { x: x, y: y }
}
function position_to_canvas_coordinate(position) {
const x_range = field_range.x[1] - field_range.x[0]
const y_range = field_range.y[1] - field_range.y[0]
const x = Array.isArray(position) ? position[0] : position.x;
const y = Array.isArray(position) ? position[1] : position.y;
return {
x: (x - field_range.x[0]) * fieldEl.width / x_range,
y: (y - field_range.y[0]) * fieldEl.width / y_range,
}
}
// helper function so we can spread
function coord_as_list(coord) {
return [coord.x, coord.y]
}
let tracker = {}
let person_counter = 0
class Person {
constructor(id) {
this.id = id;
this.history = [];
this.prediction = []
}
addToHistory(position) {
this.history.push(position);
}
}
fieldEl.addEventListener('mousedown', (event) => {
person_counter++;
tracker[person_counter] = new Person(person_counter);
is_moving = true;
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
tracker[person_counter].addToHistory(position);
trajectory_socket.send(JSON.stringify(tracker))
});
fieldEl.addEventListener('mousemove', (event) => {
if (!is_moving) return;
const mousePos = getMousePos(fieldEl, event);
const position = mouse_coordinates_to_position(mousePos)
tracker[person_counter].addToHistory(position);
trajectory_socket.send(JSON.stringify(tracker))
});
document.addEventListener('mouseup', (e) => {
is_moving = false;
tracker = {}
})
const ctx = fieldEl.getContext("2d");
function drawFrame() {
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
ctx.save();
for (let id in current_data) {
const person = current_data[id];
if (person.history.length > 1) {
const hist = structuredClone(person.history)
// draw current position:
ctx.beginPath()
ctx.arc(
...coord_as_list(position_to_canvas_coordinate(hist[hist.length - 1])),
5, //radius
0, 2 * Math.PI);
ctx.fill()
ctx.beginPath()
ctx.lineWidth = 3;
ctx.strokeStyle = "#325FA2";
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(hist.shift())));
for (const position of hist) {
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
}
ctx.stroke();
}
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
// multiple predictions can be sampled
person.predictions.forEach((prediction, i) => {
ctx.beginPath()
ctx.lineWidth = i === 1 ? 3 : 0.2;
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
for (const position of prediction) {
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
}
ctx.stroke();
});
}
}
ctx.restore();
window.requestAnimationFrame(drawFrame);
}
window.requestAnimationFrame(drawFrame);
</script>
</body>
</html>