Compare commits

...

10 commits

Author            SHA1        Message                                             Date
Ruben van de Ven  a4a4990035  Loglvl                                              2023-10-12 09:04:00 +02:00
Ruben van de Ven  99a54198fb  Support remote logging                              2023-10-12 09:03:35 +02:00
Ruben van de Ven  ffcf1c26c1  Support remote logging                              2023-10-12 09:03:11 +02:00
Ruben van de Ven  ee57604b30  Render predictions to browser                       2023-10-11 16:35:15 +02:00
Ruben van de Ven  9ba283ca9b  Trajectory prediction - test in browser             2023-10-11 13:58:09 +02:00
Ruben van de Ven  4fa3ce95ee  Trajectron as a module while supporting old models  2023-10-09 21:06:36 +02:00
Ruben van de Ven  dfa1d43f2e  Change imports to support usage as module           2023-10-09 20:27:29 +02:00
                              Note that this does require rerunning the `process_data.py` scripts
                              in the example folders so that the dill-files are updated.
Ruben van de Ven  f4d860907c  Fix compatibility with numpy > 1.20                 2023-10-04 21:16:02 +02:00
Ruben van de Ven  d182787c83  Fix compatibility with python 3.10                  2023-10-04 21:14:46 +02:00
Ruben van de Ven  d4261f1c3d  pyenv and poetry requirements                       2023-10-04 21:13:44 +02:00
29 changed files with 4124 additions and 50 deletions

1
.python-version Normal file
View file

@@ -0,0 +1 @@
3.10.4

View file

@@ -7,11 +7,11 @@ import torch
 import numpy as np
 import pandas as pd
-sys.path.append("../../trajectron")
+sys.path.append("../../")
 from tqdm import tqdm
-from model.model_registrar import ModelRegistrar
-from model.trajectron import Trajectron
-import evaluation
+from trajectron.model.model_registrar import ModelRegistrar
+from trajectron.model.trajectron import Trajectron
+import trajectron.evaluation as evaluation

 seed = 0
 np.random.seed(seed)

View file

@@ -4,10 +4,10 @@ import numpy as np
 import pandas as pd
 import dill
-sys.path.append("../../trajectron")
-from environment import Environment, Scene, Node
-from utils import maybe_makedirs
-from environment import derivative_of
+sys.path.append("../../")
+from trajectron.environment import Environment, Scene, Node
+from trajectron.utils import maybe_makedirs
+from trajectron.environment import derivative_of

 desired_max_time = 100
 pred_indices = [2, 3]

3211
poetry.lock generated Normal file

File diff suppressed because it is too large

50
pyproject.toml Normal file
View file

@@ -0,0 +1,50 @@
[tool.poetry]
name = "Trajectron-plus-plus"
version = "0.1.0"
description = "Predict trajectories for anomaly detection"
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.9,<3.12"
numpy = "^1.24.3"
opencv-python = "^4.7.0.72"
ipykernel = "^6.22.0"
torch = [
    { version="1.12.1" },
    # { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp38-cp38-linux_x86_64.whl", markers = "python_version ~= '3.8' and sys_platform == 'linux'" },
    { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl", markers = "python_version ~= '3.10' and sys_platform == 'linux'" },
]
#torchvision = [
#    { version="0.13.1" },
#    { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp38-cp38-linux_x86_64.whl", markers = "python_version ~= '3.8' and sys_platform == 'linux'" },
#    { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp310-cp310-linux_x86_64.whl", markers = "python_version ~= '3.10' and sys_platform == 'linux'" },
#]
#av = "^10.0.0"
matplotlib = "^3.5"
#numba = "^0.57.0"
#scikit-image = "^0.20.0"
#scikit-learn = "0.22.1"
#filterpy = "^1.4.5"
tqdm = "^4.65.0"
#ipywidgets = "^8.0.6"
#deep-sort-realtime = "^1.3.2"
scipy = "^1.11.3"
pandas = "^2.1.1"
orjson = "^3.9.7"
nuscenes-devkit = "^1.1.11"
pyquaternion = "^0.9.9"
dill = "^0.3.7"
ncls = "^0.0.68"
notebook = "^7.0.4"
scikit-learn = "^1.3.1"
seaborn = "^0.13.0"
setuptools = "^68.2.2"
tensorboard = "1.14"
tensorboardx = "^2.6.2.2"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

5
run_server.py Normal file
View file

@@ -0,0 +1,5 @@
from trajpred import plumber


if __name__ == "__main__":
    plumber.start()

View file

@@ -5,13 +5,12 @@ import torch
 import dill
 import random
 import pathlib
-import evaluation
 import numpy as np
-import visualization as vis
-from argument_parser import args
-from model.online.online_trajectron import OnlineTrajectron
-from model.model_registrar import ModelRegistrar
-from environment import Environment, Scene
+import trajectron.visualization as vis
+from trajectron.argument_parser import args
+from trajectron.model.online.online_trajectron import OnlineTrajectron
+from trajectron.model.model_registrar import ModelRegistrar
+from trajectron.environment import Environment, Scene
 import matplotlib.pyplot as plt

 if not torch.cuda.is_available() or args.device == 'cpu':

View file

@@ -1,6 +1,7 @@
 import numpy as np
 import pandas as pd
-from collections import Sequence, OrderedDict
+from collections.abc import Sequence
+from collections import OrderedDict

 class RingBuffer(Sequence):
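
Background on this change: the abstract base classes have lived in `collections.abc` since Python 3.3, and the deprecated aliases in `collections` itself were removed in Python 3.10, so the old combined import fails outright on the interpreter pinned in `.python-version`. If older interpreters ever needed support, the usual fallback would be a guarded import; a minimal sketch (the commit itself simply adopts the new location):

try:
    from collections.abc import Sequence  # Python 3.3+
except ImportError:
    from collections import Sequence  # alias removed in Python 3.10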

View file

@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop
+from trajectron.model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop

 class Map(object):

View file

@@ -1,7 +1,7 @@
 import random
 import numpy as np
 import pandas as pd
-from environment import DoubleHeaderNumpyArray
+from trajectron.environment import DoubleHeaderNumpyArray
 from ncls import NCLS

View file

@@ -135,10 +135,10 @@ class TemporalSceneGraph(object):
         position_cube = np.full((total_timesteps, N, 2), np.nan)

         adj_cube = np.zeros((total_timesteps, N, N), dtype=np.int8)
-        dist_cube = np.zeros((total_timesteps, N, N), dtype=np.float)
+        dist_cube = np.zeros((total_timesteps, N, N), dtype=float)

         node_type_mat = np.zeros((N, N), dtype=np.int8)
-        node_attention_mat = np.zeros((N, N), dtype=np.float)
+        node_attention_mat = np.zeros((N, N), dtype=float)

         for node_idx, node in enumerate(nodes):
             if online:
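
Background on this fix: `np.float` was never a distinct NumPy scalar type, only an alias for the builtin `float`; NumPy 1.20 deprecated the alias and 1.24 (the range pyproject.toml pins) removed it, so plain `float` is the behavior-preserving drop-in. A quick sanity check:

import numpy as np

# dtype=float resolves to NumPy's float64 (the platform double),
# exactly as the removed np.float alias did.
dist_cube = np.zeros((4, 3, 3), dtype=float)
assert dist_cube.dtype == np.float64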

View file

@@ -2,8 +2,8 @@ import numpy as np
 from scipy.interpolate import RectBivariateSpline
 from scipy.ndimage import binary_dilation
 from scipy.stats import gaussian_kde
-from utils import prediction_output_to_trajectories
-import visualization
+from trajectron.utils import prediction_output_to_trajectories
+import trajectron.visualization
 from matplotlib import pyplot as plt

View file

@@ -1,7 +1,7 @@
 import torch
 import torch.distributions as td
 import numpy as np
-from model.model_utils import ModeKeys
+from trajectron.model.model_utils import ModeKeys

 class DiscreteLatent(object):

View file

@@ -1,7 +1,7 @@
 import torch
 import torch.distributions as td
 import numpy as np
-from model.model_utils import to_one_hot
+from trajectron.model.model_utils import to_one_hot

 class GMM2D(td.Distribution):

View file

@@ -1,4 +1,4 @@
-from model.dynamics import Dynamic
+from trajectron.model.dynamics import Dynamic

 class Linear(Dynamic):

View file

@@ -1,7 +1,7 @@
 import torch
-from model.dynamics import Dynamic
-from utils import block_diag
-from model.components import GMM2D
+from trajectron.model.dynamics import Dynamic
+from trajectron.utils import block_diag
+from trajectron.model.components import GMM2D

 class SingleIntegrator(Dynamic):

View file

@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
-from model.dynamics import Dynamic
-from utils import block_diag
-from model.components import GMM2D
+from trajectron.model.dynamics import Dynamic
+from trajectron.utils import block_diag
+from trajectron.model.components import GMM2D

 class Unicycle(Dynamic):

View file

@@ -2,10 +2,10 @@ import warnings
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from model.components import *
-from model.model_utils import *
-import model.dynamics as dynamic_module
-from environment.scene_graph import DirectedEdge
+from trajectron.model.components import *
+from trajectron.model.model_utils import *
+import trajectron.model.dynamics as dynamic_module
+from trajectron.environment.scene_graph import DirectedEdge

 class MultimodalGenerativeCVAE(object):

View file

@@ -1,11 +1,23 @@
 import os
 import torch
 import torch.nn as nn
+import pickle

 def get_model_device(model):
     return next(model.parameters()).device

+class PickleModuleCompatibility:
+    '''
+    Migrating Trajectron++ to a module structure
+    while maintaining compatibility with models generated
+    before the migration
+    '''
+    class Unpickler(pickle.Unpickler):
+        def find_class(self, module, name):
+            if module == 'model' or module[:6] == 'model.':
+                module = 'trajectron.' + module
+            return super().find_class(module, name)

 class ModelRegistrar(nn.Module):
     def __init__(self, model_dir, device):
@@ -65,7 +77,7 @@ class ModelRegistrar(nn.Module):
         print('')
         print('Loading from ' + save_path)
-        self.model_dict = torch.load(save_path, map_location=self.device)
+        self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
         print('Loaded!')
         print('')
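
How this works: `torch.load` routes deserialization through `pickle_module.Unpickler`, so the plain class above can stand in for the `pickle` module, and every class lookup in an old checkpoint passes through `find_class`, which rewrites the pre-migration top-level `model.*` paths to `trajectron.model.*`. (For checkpoints in torch's older legacy serialization format, `torch.load` also calls `pickle_module.load`, which this minimal stand-in does not define; whether that matters depends on how the models were saved.) This is presumably also why commit dfa1d43f2e notes that the `process_data.py` dill-files must be regenerated: those pickles record module paths too, and the shim is only wired into `torch.load`. A standalone sketch of the same trick for an ordinary pickle file (`RenamingUnpickler` and `load_premigration` are illustrative names, not from the repo):

import pickle

class RenamingUnpickler(pickle.Unpickler):
    # Rewrite module paths recorded by pre-migration pickles so the
    # classes resolve inside the new trajectron package.
    def find_class(self, module, name):
        if module == 'model' or module.startswith('model.'):
            module = 'trajectron.' + module
        return super().find_class(module, name)

def load_premigration(path):
    with open(path, 'rb') as f:
        return RenamingUnpickler(f).load()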

View file

@@ -4,13 +4,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from collections import defaultdict, Counter
-from model.components import *
-from model.model_utils import *
-from model.dataset import get_relative_robot_traj
-import model.dynamics as dynamic_module
-from model.mgcvae import MultimodalGenerativeCVAE
-from environment.scene_graph import DirectedEdge
-from environment.node_type import NodeType
+from trajectron.model.components import *
+from trajectron.model.model_utils import *
+from trajectron.model.dataset import get_relative_robot_traj
+import trajectron.model.dynamics as dynamic_module
+from trajectron.model.mgcvae import MultimodalGenerativeCVAE
+from trajectron.environment.scene_graph import DirectedEdge
+from trajectron.environment.node_type import NodeType

 class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):

View file

@@ -1,10 +1,10 @@
 import torch
 import numpy as np
 from collections import Counter
-from model.trajectron import Trajectron
-from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
-from model.model_utils import ModeKeys
-from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
+from trajectron.model.trajectron import Trajectron
+from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
+from trajectron.model.model_utils import ModeKeys
+from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of

 class OnlineTrajectron(Trajectron):

View file

@@ -1,7 +1,7 @@
 import torch
 import numpy as np
-from model.mgcvae import MultimodalGenerativeCVAE
-from model.dataset import get_timesteps_data, restore
+from trajectron.model.mgcvae import MultimodalGenerativeCVAE
+from trajectron.model.dataset import get_timesteps_data, restore

 class Trajectron(object):

View file

@@ -1,4 +1,4 @@
-from utils import prediction_output_to_trajectories
+from trajectron.utils import prediction_output_to_trajectories
 from scipy import linalg
 import matplotlib.pyplot as plt
 import matplotlib.patches as patches

0
trajpred/__init__.py Normal file
View file

153
trajpred/config.py Normal file
View file

@@ -0,0 +1,153 @@
import argparse

parser = argparse.ArgumentParser()

parser.add_argument(
    '--verbose',
    '-v',
    help="Increase verbosity. Add multiple times to increase further.",
    action='count', default=0
)
parser.add_argument(
    '--remote-log-addr',
    help="Connect to a remote logger like cutelog. Specify the ip",
    type=str,
)
parser.add_argument(
    '--remote-log-port',
    help="Connect to a remote logger like cutelog. Specify the port",
    type=int,
    default=19996
)
# parser.add_argument('--foo')

inference_parser = parser.add_argument_group('inference server')
connection_parser = parser.add_argument_group('connection')

inference_parser.add_argument("--model_dir",
                              help="directory with the model to use for inference",
                              type=str,  # TODO: make into Path
                              default='./experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
inference_parser.add_argument("--conf",
                              help="path to json config file for hyperparameters, relative to model_dir",
                              type=str,
                              default='config.json')

# Model Parameters (hyperparameters)
inference_parser.add_argument("--offline_scene_graph",
                              help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
                              type=str,
                              default='yes')
inference_parser.add_argument("--dynamic_edges",
                              help="whether to use dynamic edges or not, options are 'no' and 'yes'",
                              type=str,
                              default='yes')
inference_parser.add_argument("--edge_state_combine_method",
                              help="the method to use for combining edges of the same type",
                              type=str,
                              default='sum')
inference_parser.add_argument("--edge_influence_combine_method",
                              help="the method to use for combining edge influences",
                              type=str,
                              default='attention')
inference_parser.add_argument('--edge_addition_filter',
                              nargs='+',
                              help="what scaling to use for edges as they're created",
                              type=float,
                              default=[0.25, 0.5, 0.75, 1.0])  # We don't automatically pad left with 0.0, if you want a sharp
                                                               # and short edge addition, then you need to have a 0.0 at the
                                                               # beginning, e.g. [0.0, 1.0].
inference_parser.add_argument('--edge_removal_filter',
                              nargs='+',
                              help="what scaling to use for edges as they're removed",
                              type=float,
                              default=[1.0, 0.0])  # We don't automatically pad right with 0.0, if you want a sharp drop off like
                                                   # the default, then you need to have a 0.0 at the end.
inference_parser.add_argument('--incl_robot_node',
                              help="whether to include a robot node in the graph or simply model all agents",
                              action='store_true')
inference_parser.add_argument('--map_encoding',
                              help="Whether to use map encoding or not",
                              action='store_true')
inference_parser.add_argument('--no_edge_encoding',
                              help="Whether to use neighbors edge encoding",
                              action='store_true')
inference_parser.add_argument('--batch_size',
                              help='training batch size',
                              type=int,
                              default=256)
inference_parser.add_argument('--k_eval',
                              help='how many samples to take during evaluation',
                              type=int,
                              default=25)

# Data Parameters
inference_parser.add_argument("--eval_data_dict",
                              help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
                              type=str,
                              default='./experiments/processed/eth_test.pkl')
inference_parser.add_argument("--output_dir",
                              help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
                              type=str,
                              default='../experiments/pedestrians/OUT/test_inference')

# inference_parser.add_argument('--device',
#                               help='what device to perform training on',
#                               type=str,
#                               default='cuda:0')
inference_parser.add_argument("--eval_device",
                              help="what device to use during inference",
                              type=str,
                              default="cpu")

inference_parser.add_argument('--seed',
                              help='manual seed to use, default is 123',
                              type=int,
                              default=123)

# Internal connections.
connection_parser.add_argument('--zmq-trajectory-addr',
                               help='Manually specify communication addr for the trajectory messages',
                               type=str,
                               default="ipc:///tmp/feeds/traj")
connection_parser.add_argument('--zmq-camera-stream-addr',
                               help='Manually specify communication addr for the camera stream messages',
                               type=str,
                               default="ipc:///tmp/feeds/img")
connection_parser.add_argument('--zmq-prediction-addr',
                               help='Manually specify communication addr for the prediction messages',
                               type=str,
                               default="ipc:///tmp/feeds/preds")
connection_parser.add_argument('--ws-port',
                               help='Port to listen for incoming websocket connections. Also serves the testing html-page.',
                               type=int,
                               default=8888)
connection_parser.add_argument('--bypass-prediction',
                               help='For debugging purpose: websocket input immediately to output',
                               action='store_true')
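
Because the parser is built at module import time, the defaults above can be inspected without starting anything; a small sketch, assuming the repository root is on the import path:

from trajpred.config import parser

args = parser.parse_args([])        # empty list: take defaults, ignore sys.argv
print(args.ws_port)                 # 8888
print(args.zmq_trajectory_addr)     # ipc:///tmp/feeds/traj
print(args.remote_log_port)         # 19996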

40
trajpred/plumber.py Normal file
View file

@@ -0,0 +1,40 @@
import logging
from logging.handlers import SocketHandler
from multiprocessing import Process, Queue

from trajpred.config import parser
from trajpred.prediction_server import InferenceServer, run_inference_server
from trajpred.socket_forwarder import run_ws_forwarder

logger = logging.getLogger("trajpred.plumbing")


def start():
    args = parser.parse_args()
    loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
    logging.basicConfig(
        level=loglevel,
    )

    if args.remote_log_addr:
        logging.captureWarnings(True)
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.NOTSET)  # to send all records to cutelog
        socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
        root_logger.addHandler(socket_handler)

    # instantiate processes with arguments
    procs = [
        Process(target=run_ws_forwarder, args=(args,)),
    ]

    if not args.bypass_prediction:
        procs.append(
            Process(target=run_inference_server, args=(args,)),
        )

    logger.info("start")
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()
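
The verbosity ternary maps no `-v` to INFO, one `-v` to DEBUG, and `-vv` or more to NOTSET (pass everything through, which suits the remote SocketHandler). A hypothetical smoke test that exercises only the forwarder process, equivalent to running `python run_server.py -v --bypass-prediction`:

import sys

from trajpred import plumber

sys.argv = ["run_server.py", "-v", "--bypass-prediction"]
plumber.start()  # starts just the websocket forwarder; interrupt to stop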

270
trajpred/prediction_server.py Normal file
View file

@@ -0,0 +1,270 @@
# adapted from Trajectron++ online_server.py
import logging
from multiprocessing import Queue
import os
import time
import json
import torch
import dill
import random
import pathlib
import numpy as np
from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt
import zmq

logger = logging.getLogger("trajpred.inference")

# if not torch.cuda.is_available() or self.config.device == 'cpu':
#     self.config.device = torch.device('cpu')
# else:
#     if torch.cuda.device_count() == 1:
#         # If you have CUDA_VISIBLE_DEVICES set, which you should,
#         # then this will prevent leftover flag arguments from
#         # messing with the device allocation.
#         self.config.device = 'cuda:0'
#
#     self.config.device = torch.device(self.config.device)


def create_online_env(env, hyperparams, scene_idx, init_timestep):
    test_scene = env.scenes[scene_idx]

    online_scene = Scene(timesteps=init_timestep + 1,
                         map=test_scene.map,
                         dt=test_scene.dt)
    online_scene.nodes = test_scene.get_nodes_clipped_at_time(
        timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'],
                            init_timestep + 1),
        state=hyperparams['state'])
    online_scene.robot = test_scene.robot
    online_scene.calculate_scene_graph(attention_radius=env.attention_radius,
                                       edge_addition_filter=hyperparams['edge_addition_filter'],
                                       edge_removal_filter=hyperparams['edge_removal_filter'])

    return Environment(node_type_list=env.node_type_list,
                       standardization=env.standardization,
                       scenes=[online_scene],
                       attention_radius=env.attention_radius,
                       robot_type=env.robot_type)


def get_maps_for_input(input_dict, scene, hyperparams):
    scene_maps = list()
    scene_pts = list()
    heading_angles = list()
    patch_sizes = list()
    nodes_with_maps = list()
    for node in input_dict:
        if node.type in hyperparams['map_encoder']:
            x = input_dict[node]
            me_hyp = hyperparams['map_encoder'][node.type]
            if 'heading_state_index' in me_hyp:
                heading_state_index = me_hyp['heading_state_index']
                # We have to rotate the map in the opposite direction of the agent to match them
                if type(heading_state_index) is list:  # infer from velocity or heading vector
                    heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
                                                x[-1, heading_state_index[0]]) * 180 / np.pi
                else:
                    heading_angle = -x[-1, heading_state_index] * 180 / np.pi
            else:
                heading_angle = None

            scene_map = scene.map[node.type]
            map_point = x[-1, :2]

            patch_size = hyperparams['map_encoder'][node.type]['patch_size']

            scene_maps.append(scene_map)
            scene_pts.append(map_point)
            heading_angles.append(heading_angle)
            patch_sizes.append(patch_size)
            nodes_with_maps.append(node)

    if heading_angles[0] is None:
        heading_angles = None
    else:
        heading_angles = torch.Tensor(heading_angles)

    maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
                                                               scene_pts=torch.Tensor(scene_pts),
                                                               patch_size=patch_sizes[0],
                                                               rotation=heading_angles)

    maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
    return maps_dict


class InferenceServer:
    def __init__(self, config: dict):
        self.config = config

        context = zmq.Context()
        self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
        self.trajectory_socket.connect(config.zmq_trajectory_addr)
        self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')

        self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
        self.prediction_socket.bind(config.zmq_prediction_addr)
        print(self.prediction_socket)

    def run(self):
        if self.config.seed is not None:
            random.seed(self.config.seed)
            np.random.seed(self.config.seed)
            torch.manual_seed(self.config.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.config.seed)

        # Choose one of the model directory names under the experiment/*/models folders.
        # Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
        # model_dir = os.path.join(self.config.log_dir, 'int_ee')
        # model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3'

        # Load hyperparameters from json
        config_file = os.path.join(self.config.model_dir, self.config.conf)
        if not os.path.exists(config_file):
            raise ValueError('Config json not found!')
        with open(config_file, 'r') as conf_json:
            hyperparams = json.load(conf_json)

        # Add hyperparams from arguments
        hyperparams['dynamic_edges'] = self.config.dynamic_edges
        hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method
        hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method
        hyperparams['edge_addition_filter'] = self.config.edge_addition_filter
        hyperparams['edge_removal_filter'] = self.config.edge_removal_filter
        hyperparams['batch_size'] = self.config.batch_size
        hyperparams['k_eval'] = self.config.k_eval
        hyperparams['offline_scene_graph'] = self.config.offline_scene_graph
        hyperparams['incl_robot_node'] = self.config.incl_robot_node
        hyperparams['edge_encoding'] = not self.config.no_edge_encoding
        hyperparams['use_map_encoding'] = self.config.map_encoding

        output_save_dir = os.path.join(self.config.output_dir, 'pred_figs')
        pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)

        with open(self.config.eval_data_dict, 'rb') as f:
            eval_env = dill.load(f, encoding='latin1')

        if eval_env.robot_type is None and hyperparams['incl_robot_node']:
            eval_env.robot_type = eval_env.NodeType[0]  # TODO: Make more general, allow the user to specify?
            for scene in eval_env.scenes:
                scene.add_robot_from_nodes(eval_env.robot_type)

        logger.info('Loaded data from %s' % (self.config.eval_data_dict,))

        # Creating a dummy environment with a single scene that contains information about the world.
        # When using this code, feel free to use whichever scene index or initial timestep you wish.
        scene_idx = 0

        # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
        # so that you can immediately start incremental inference from the 3rd timestep onwards.
        init_timestep = 1

        eval_scene = eval_env.scenes[scene_idx]
        online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)

        model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device)
        model_registrar.load_models(iter_num=100)

        trajectron = OnlineTrajectron(model_registrar,
                                      hyperparams,
                                      self.config.eval_device)

        # If you want to see what different robot futures do to the predictions, uncomment this line as well as
        # related "... += adjustment" lines below.
        # adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)

        # Here's how you'd incrementally run the model, e.g. with streaming data.
        trajectron.set_environment(online_env, init_timestep)

        for timestep in range(init_timestep + 1, eval_scene.timesteps):
            input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])

            maps = None
            if hyperparams['use_map_encoding']:
                maps = get_maps_for_input(input_dict, eval_scene, hyperparams)

            robot_present_and_future = None
            if eval_scene.robot is not None and hyperparams['incl_robot_node']:
                robot_present_and_future = eval_scene.robot.get(np.array([timestep,
                                                                          timestep + hyperparams['prediction_horizon']]),
                                                                hyperparams['state'][eval_scene.robot.type],
                                                                padding=0.0)
                robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0)
                # robot_present_and_future += adjustment

            start = time.time()
            dists, preds = trajectron.incremental_forward(input_dict,
                                                          maps,
                                                          prediction_horizon=6,
                                                          num_samples=51,
                                                          robot_present_and_future=robot_present_and_future,
                                                          full_dist=True)
            end = time.time()
            logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
                                                                                    1. / (end - start), len(trajectron.nodes),
                                                                                    trajectron.scene_graph.get_num_edges()))

            # unsure what this bit from online_prediction.py does:
            # detailed_preds_dict = dict()
            # for node in eval_scene.nodes:
            #     if node in preds:
            #         detailed_preds_dict[node] = preds[node]

            # adapted from trajectron.visualization
            # prediction_dict provides the actual predictions
            # histories_dict provides the trajectory used for prediction
            # futures_dict is the Ground Truth, which is unavailable in an online setting
            prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
                                                                                              eval_scene.dt,
                                                                                              hyperparams['maximum_history_length'],
                                                                                              hyperparams['prediction_horizon'])

            assert(len(prediction_dict.keys()) <= 1)
            if len(prediction_dict.keys()) == 0:
                return
            ts_key = list(prediction_dict.keys())[0]

            prediction_dict = prediction_dict[ts_key]
            histories_dict = histories_dict[ts_key]
            futures_dict = futures_dict[ts_key]

            response = {}

            for node in histories_dict:
                history = histories_dict[node]
                # future = futures_dict[node]
                predictions = prediction_dict[node]

                if np.isnan(history[-1]).any():
                    continue

                response[node.id] = {
                    'id': node.id,
                    'history': history.tolist(),
                    'predictions': predictions[0].tolist()  # use batch 0
                }

            data = json.dumps(response)
            self.prediction_socket.send_string(data)
            # time.sleep(1)
            # print(prediction_dict)
            # print(histories_dict)
            # print(futures_dict)


def run_inference_server(config):
    s = InferenceServer(config)
    s.run()
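
The server publishes each timestep's response as a single JSON string on a ZeroMQ PUB socket, so any process can tap the feed. A minimal sketch of a consumer, assuming the default `--zmq-prediction-addr` from trajpred/config.py:

import json
import zmq

context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.connect("ipc:///tmp/feeds/preds")  # default --zmq-prediction-addr
socket.setsockopt(zmq.SUBSCRIBE, b"")     # subscribe to everything

while True:
    response = json.loads(socket.recv_string())
    for node_id, track in response.items():
        print(node_id,
              f"{len(track['history'])} history points,",
              f"{len(track['predictions'])} predicted steps")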

162
trajpred/socket_forwarder.py Normal file
View file

@@ -0,0 +1,162 @@
from argparse import Namespace
import asyncio
import logging
from typing import Set, Union, Dict, Any
from typing_extensions import Self
from urllib.error import HTTPError
import tornado.ioloop
import tornado.web
import tornado.websocket
import zmq
import zmq.asyncio

logger = logging.getLogger("trajpred.forwarder")


class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
    def initialize(self, zmq_socket: zmq.asyncio.Socket):
        self.zmq_socket = zmq_socket

    async def on_message(self, message):
        logger.debug(f"receive msg")
        try:
            await self.zmq_socket.send_string(message)
            # msg = json.loads(message)
        except Exception as e:
            # self.send({'alert': 'Invalid request: {}'.format(e)})
            logger.exception(e)
        # self.write_message(u"You said: " + message)

    def open(self, p=None):
        logger.info(f"connected {self.request.remote_ip}")

    # client disconnected
    def on_close(self):
        logger.info(f"Client disconnected: {self.request.remote_ip}")


class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
    connections: Set[Self] = set()

    def initialize(self, config):
        self.config = config

    def on_message(self, message):
        logger.warning(f"Receiving message on send-only ws handler: {message}")

    def open(self, p=None):
        logger.info(f"Prediction WS connected {self.request.remote_ip}")
        self.__class__.connections.add(self)

    # client disconnected
    def on_close(self):
        self.__class__.rmConnection(self)
        logger.info(f"Client disconnected: {self.request.remote_ip}")

    @classmethod
    def rmConnection(cls, client):
        if client not in cls.connections:
            return
        cls.connections.remove(client)

    @classmethod
    def hasConnection(cls, client):
        return client in cls.connections

    @classmethod
    def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
        if msg is None:
            logger.critical("Tried to send 'none'")
            return

        toRemove = []
        for client in cls.connections:
            try:
                client.write_message(msg)
            except tornado.websocket.WebSocketClosedError as e:
                logger.warning(f"Not properly closed websocket connection")
                toRemove.append(client)  # If we remove it here from the set we get an exception about changing set size during iteration
        for client in toRemove:
            cls.rmConnection(client)


class DemoHandler(tornado.web.RequestHandler):
    def initialize(self, config: Namespace):
        self.config = config

    def get(self):
        self.render("index.html", ws_port=self.config.ws_port)


class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
    def set_extra_headers(self, path):
        """For subclass to add extra headers to the response"""
        if path[-5:] == ".html":
            self.set_header("Access-Control-Allow-Origin", "*")
        if path[-4:] == ".svg":
            self.set_header("Content-Type", "image/svg+xml")


class WsRouter:
    def __init__(self, config: Namespace):
        self.config = config

        context = zmq.asyncio.Context()
        self.trajectory_socket = context.socket(zmq.PUB)
        self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)

        self.prediction_socket = context.socket(zmq.SUB)
        self.prediction_socket.connect(config.zmq_prediction_addr)
        self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')

        self.application = tornado.web.Application(
            [
                (
                    r"/ws/trajectory",
                    WebSocketTrajectoryHandler,
                    {
                        "zmq_socket": self.trajectory_socket
                    },
                ),
                (
                    r"/ws/prediction",
                    WebSocketPredictionHandler,
                    {
                        "config": config,
                    },
                ),
                (r"/", DemoHandler, {"config": config}),
                # (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}),
            ],
            template_path='trajpred/web/',
            compiled_template_cache=False)

    def start(self):
        evt_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(evt_loop)

        # loop = tornado.ioloop.IOLoop.current()
        logger.info(f"Listen on {self.config.ws_port}")
        self.application.listen(self.config.ws_port)
        loop = asyncio.get_event_loop()

        task = evt_loop.create_task(self.prediction_forwarder())

        evt_loop.run_forever()

    async def prediction_forwarder(self):
        logger.info("Starting prediction forwarder")
        while True:
            msg = await self.prediction_socket.recv_string()
            logger.debug(f"Forward prediction message of {len(msg)} chars")
            WebSocketPredictionHandler.write_to_clients(msg)


def run_ws_forwarder(config: Namespace):
    router = WsRouter(config)
    router.start()
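
Trajectories do not have to come from the demo page; anything that speaks websockets can feed `/ws/trajectory`. A sketch using the third-party `websockets` package (an assumption, not a dependency of this repo) and the same JSON shape that trajpred/web/index.html sends:

import asyncio
import json

import websockets  # third-party, not a repo dependency


async def send_track():
    uri = "ws://localhost:8888/ws/trajectory"  # default --ws-port
    async with websockets.connect(uri) as ws:
        history = [{"x": 0.0, "y": 0.0}, {"x": 0.5, "y": 0.1}, {"x": 1.0, "y": 0.3}]
        tracker = {"1": {"id": 1, "history": history, "prediction": []}}
        await ws.send(json.dumps(tracker))

asyncio.run(send_track())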

170
trajpred/web/index.html Normal file
View file

@@ -0,0 +1,170 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Trajectory Prediction Browser Test</title>
    <style>
        body {
            background: black;
        }

        #field {
            background: white;
            width: 100%;
            height: 100%;
        }
    </style>
</head>

<body>
    <canvas id="field" width="1500" height="1500">
    </canvas>
    <script>
        // minified https://github.com/joewalnes/reconnecting-websocket
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
    </script>
    <script>
        // map the field to coordinates of our dummy tracker
        const field_range = { x: [-10, 10], y: [-10, 10] }

        // Create WebSocket connection.
        const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
        const prediction_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/prediction`);

        let is_moving = false;
        const fieldEl = document.getElementById('field');
        let current_data = {}

        // Listen for messages
        prediction_socket.addEventListener("message", (event) => {
            // console.log("Message from server ", event.data);
            current_data = JSON.parse(event.data);
        });

        function getMousePos(canvas, evt) {
            const rect = canvas.getBoundingClientRect();
            return {
                x: evt.clientX - rect.left,
                y: evt.clientY - rect.top
            };
        }

        function mouse_coordinates_to_position(coordinates) {
            const x_range = field_range.x[1] - field_range.x[0]
            const x = (coordinates.x / fieldEl.clientWidth) * x_range + field_range.x[0]
            const y_range = field_range.y[1] - field_range.y[0]
            const y = (coordinates.y / fieldEl.clientWidth) * y_range + field_range.y[0]
            return { x: x, y: y }
        }

        function position_to_canvas_coordinate(position) {
            const x_range = field_range.x[1] - field_range.x[0]
            const y_range = field_range.y[1] - field_range.y[0]
            const x = Array.isArray(position) ? position[0] : position.x;
            const y = Array.isArray(position) ? position[1] : position.y;
            return {
                x: (x - field_range.x[0]) * fieldEl.width / x_range,
                y: (y - field_range.y[0]) * fieldEl.width / y_range,
            }
        }

        // helper function so we can spread
        function coord_as_list(coord) {
            return [coord.x, coord.y]
        }

        let tracker = {}
        let person_counter = 0

        class Person {
            constructor(id) {
                this.id = id;
                this.history = [];
                this.prediction = []
            }

            addToHistory(position) {
                this.history.push(position);
            }
        }

        fieldEl.addEventListener('mousedown', (event) => {
            person_counter++;
            tracker[person_counter] = new Person(person_counter);
            is_moving = true;
            const mousePos = getMousePos(fieldEl, event);
            const position = mouse_coordinates_to_position(mousePos)
            tracker[person_counter].addToHistory(position);
            trajectory_socket.send(JSON.stringify(tracker))
        });

        fieldEl.addEventListener('mousemove', (event) => {
            if (!is_moving) return;
            const mousePos = getMousePos(fieldEl, event);
            const position = mouse_coordinates_to_position(mousePos)
            tracker[person_counter].addToHistory(position);
            trajectory_socket.send(JSON.stringify(tracker))
        });

        document.addEventListener('mouseup', (e) => {
            is_moving = false;
            tracker = {}
        })

        const ctx = fieldEl.getContext("2d");

        function drawFrame() {
            ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
            ctx.save();

            for (let id in current_data) {
                const person = current_data[id];
                if (person.history.length > 1) {
                    const hist = structuredClone(person.history)
                    // draw current position:
                    ctx.beginPath()
                    ctx.arc(
                        ...coord_as_list(position_to_canvas_coordinate(hist[hist.length - 1])),
                        5, // radius
                        0, 2 * Math.PI);
                    ctx.fill()

                    ctx.beginPath()
                    ctx.lineWidth = 3;
                    ctx.strokeStyle = "#325FA2";
                    ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(hist.shift())));
                    for (const position of hist) {
                        ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
                    }
                    ctx.stroke();
                }

                if (person.hasOwnProperty('predictions') && person.predictions.length > 0) {
                    // multiple predictions can be sampled
                    person.predictions.forEach((prediction, i) => {
                        ctx.beginPath()
                        ctx.lineWidth = i === 1 ? 3 : 0.2;
                        ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
                        // start from current position:
                        ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
                        for (const position of prediction) {
                            ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
                        }
                        ctx.stroke();
                    });
                }
            }
            ctx.restore();

            window.requestAnimationFrame(drawFrame);
        }

        window.requestAnimationFrame(drawFrame);
    </script>
</body>

</html>