Compare commits
No commits in common. "a4a4990035bba3aa4d960713f6660b33f2e3a1b2" and "1031c7bd1a444273af378c1ec1dcca907ba59830" have entirely different histories.
a4a4990035
...
1031c7bd1a
29 changed files with 50 additions and 4124 deletions
|
@ -1 +0,0 @@
|
||||||
3.10.4
|
|
|
@ -7,11 +7,11 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
sys.path.append("../../")
|
sys.path.append("../../trajectron")
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trajectron.model.model_registrar import ModelRegistrar
|
from model.model_registrar import ModelRegistrar
|
||||||
from trajectron.model.trajectron import Trajectron
|
from model.trajectron import Trajectron
|
||||||
import trajectron.evaluation as evaluation
|
import evaluation
|
||||||
|
|
||||||
seed = 0
|
seed = 0
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
|
@ -4,10 +4,10 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import dill
|
import dill
|
||||||
|
|
||||||
sys.path.append("../../")
|
sys.path.append("../../trajectron")
|
||||||
from trajectron.environment import Environment, Scene, Node
|
from environment import Environment, Scene, Node
|
||||||
from trajectron.utils import maybe_makedirs
|
from utils import maybe_makedirs
|
||||||
from trajectron.environment import derivative_of
|
from environment import derivative_of
|
||||||
|
|
||||||
desired_max_time = 100
|
desired_max_time = 100
|
||||||
pred_indices = [2, 3]
|
pred_indices = [2, 3]
|
||||||
|
|
3211
poetry.lock
generated
3211
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,50 +0,0 @@
|
||||||
[tool.poetry]
|
|
||||||
name = "Trajectron-plus-plus"
|
|
||||||
version = "0.1.0"
|
|
||||||
description = "Predict trajectories for anomaly detection"
|
|
||||||
authors = ["Ruben van de Ven <git@rubenvandeven.com>"]
|
|
||||||
readme = "README.md"
|
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
|
||||||
python = "^3.9,<3.12"
|
|
||||||
numpy = "^1.24.3"
|
|
||||||
opencv-python = "^4.7.0.72"
|
|
||||||
ipykernel = "^6.22.0"
|
|
||||||
torch = [
|
|
||||||
{ version="1.12.1" },
|
|
||||||
# { url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp38-cp38-linux_x86_64.whl", markers = "python_version ~= '3.8' and sys_platform == 'linux'" },
|
|
||||||
{ url = "https://download.pytorch.org/whl/cu113/torch-1.12.1%2Bcu113-cp310-cp310-linux_x86_64.whl", markers = "python_version ~= '3.10' and sys_platform == 'linux'" },
|
|
||||||
]
|
|
||||||
#torchvision = [
|
|
||||||
# { version="0.13.1" },
|
|
||||||
# { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp38-cp38-linux_x86_64.whl", markers = "python_version ~= '3.8' and sys_platform == 'linux'" },
|
|
||||||
# { url = "https://download.pytorch.org/whl/cu113/torchvision-0.13.1%2Bcu113-cp310-cp310-linux_x86_64.whl", markers = "python_version ~= '3.10' and sys_platform == 'linux'" },
|
|
||||||
#]
|
|
||||||
|
|
||||||
#av = "^10.0.0"
|
|
||||||
matplotlib = "^3.5"
|
|
||||||
#numba = "^0.57.0"
|
|
||||||
#scikit-image = "^0.20.0"
|
|
||||||
#scikit-learn = "0.22.1"
|
|
||||||
#filterpy = "^1.4.5"
|
|
||||||
tqdm = "^4.65.0"
|
|
||||||
#ipywidgets = "^8.0.6"
|
|
||||||
#deep-sort-realtime = "^1.3.2"
|
|
||||||
scipy = "^1.11.3"
|
|
||||||
pandas = "^2.1.1"
|
|
||||||
orjson = "^3.9.7"
|
|
||||||
nuscenes-devkit = "^1.1.11"
|
|
||||||
pyquaternion = "^0.9.9"
|
|
||||||
dill = "^0.3.7"
|
|
||||||
ncls = "^0.0.68"
|
|
||||||
notebook = "^7.0.4"
|
|
||||||
scikit-learn = "^1.3.1"
|
|
||||||
seaborn = "^0.13.0"
|
|
||||||
setuptools = "^68.2.2"
|
|
||||||
tensorboard = "1.14"
|
|
||||||
tensorboardx = "^2.6.2.2"
|
|
||||||
|
|
||||||
|
|
||||||
[build-system]
|
|
||||||
requires = ["poetry-core"]
|
|
||||||
build-backend = "poetry.core.masonry.api"
|
|
|
@ -1,5 +0,0 @@
|
||||||
from trajpred import plumber
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
plumber.start()
|
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from collections.abc import Sequence
|
from collections import Sequence, OrderedDict
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
|
|
||||||
class RingBuffer(Sequence):
|
class RingBuffer(Sequence):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from trajectron.model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop
|
from model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop
|
||||||
|
|
||||||
|
|
||||||
class Map(object):
|
class Map(object):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from trajectron.environment import DoubleHeaderNumpyArray
|
from environment import DoubleHeaderNumpyArray
|
||||||
from ncls import NCLS
|
from ncls import NCLS
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -135,10 +135,10 @@ class TemporalSceneGraph(object):
|
||||||
position_cube = np.full((total_timesteps, N, 2), np.nan)
|
position_cube = np.full((total_timesteps, N, 2), np.nan)
|
||||||
|
|
||||||
adj_cube = np.zeros((total_timesteps, N, N), dtype=np.int8)
|
adj_cube = np.zeros((total_timesteps, N, N), dtype=np.int8)
|
||||||
dist_cube = np.zeros((total_timesteps, N, N), dtype=float)
|
dist_cube = np.zeros((total_timesteps, N, N), dtype=np.float)
|
||||||
|
|
||||||
node_type_mat = np.zeros((N, N), dtype=np.int8)
|
node_type_mat = np.zeros((N, N), dtype=np.int8)
|
||||||
node_attention_mat = np.zeros((N, N), dtype=float)
|
node_attention_mat = np.zeros((N, N), dtype=np.float)
|
||||||
|
|
||||||
for node_idx, node in enumerate(nodes):
|
for node_idx, node in enumerate(nodes):
|
||||||
if online:
|
if online:
|
||||||
|
|
|
@ -2,8 +2,8 @@ import numpy as np
|
||||||
from scipy.interpolate import RectBivariateSpline
|
from scipy.interpolate import RectBivariateSpline
|
||||||
from scipy.ndimage import binary_dilation
|
from scipy.ndimage import binary_dilation
|
||||||
from scipy.stats import gaussian_kde
|
from scipy.stats import gaussian_kde
|
||||||
from trajectron.utils import prediction_output_to_trajectories
|
from utils import prediction_output_to_trajectories
|
||||||
import trajectron.visualization
|
import visualization
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributions as td
|
import torch.distributions as td
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from trajectron.model.model_utils import ModeKeys
|
from model.model_utils import ModeKeys
|
||||||
|
|
||||||
|
|
||||||
class DiscreteLatent(object):
|
class DiscreteLatent(object):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributions as td
|
import torch.distributions as td
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from trajectron.model.model_utils import to_one_hot
|
from model.model_utils import to_one_hot
|
||||||
|
|
||||||
|
|
||||||
class GMM2D(td.Distribution):
|
class GMM2D(td.Distribution):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from trajectron.model.dynamics import Dynamic
|
from model.dynamics import Dynamic
|
||||||
|
|
||||||
|
|
||||||
class Linear(Dynamic):
|
class Linear(Dynamic):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from trajectron.model.dynamics import Dynamic
|
from model.dynamics import Dynamic
|
||||||
from trajectron.utils import block_diag
|
from utils import block_diag
|
||||||
from trajectron.model.components import GMM2D
|
from model.components import GMM2D
|
||||||
|
|
||||||
|
|
||||||
class SingleIntegrator(Dynamic):
|
class SingleIntegrator(Dynamic):
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from trajectron.model.dynamics import Dynamic
|
from model.dynamics import Dynamic
|
||||||
from trajectron.utils import block_diag
|
from utils import block_diag
|
||||||
from trajectron.model.components import GMM2D
|
from model.components import GMM2D
|
||||||
|
|
||||||
|
|
||||||
class Unicycle(Dynamic):
|
class Unicycle(Dynamic):
|
||||||
|
|
|
@ -2,10 +2,10 @@ import warnings
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from trajectron.model.components import *
|
from model.components import *
|
||||||
from trajectron.model.model_utils import *
|
from model.model_utils import *
|
||||||
import trajectron.model.dynamics as dynamic_module
|
import model.dynamics as dynamic_module
|
||||||
from trajectron.environment.scene_graph import DirectedEdge
|
from environment.scene_graph import DirectedEdge
|
||||||
|
|
||||||
|
|
||||||
class MultimodalGenerativeCVAE(object):
|
class MultimodalGenerativeCVAE(object):
|
||||||
|
|
|
@ -1,23 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import pickle
|
|
||||||
|
|
||||||
def get_model_device(model):
|
def get_model_device(model):
|
||||||
return next(model.parameters()).device
|
return next(model.parameters()).device
|
||||||
|
|
||||||
class PickleModuleCompatibility:
|
|
||||||
'''
|
|
||||||
Migrating Trajectron++ to a module structure
|
|
||||||
while maintaining compatibility with models generated
|
|
||||||
before the migration
|
|
||||||
'''
|
|
||||||
class Unpickler(pickle.Unpickler):
|
|
||||||
def find_class(self, module, name):
|
|
||||||
if module == 'model' or module[:6] == 'model.':
|
|
||||||
module = 'trajectron.' + module
|
|
||||||
return super().find_class(module, name)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistrar(nn.Module):
|
class ModelRegistrar(nn.Module):
|
||||||
def __init__(self, model_dir, device):
|
def __init__(self, model_dir, device):
|
||||||
|
@ -77,7 +65,7 @@ class ModelRegistrar(nn.Module):
|
||||||
|
|
||||||
print('')
|
print('')
|
||||||
print('Loading from ' + save_path)
|
print('Loading from ' + save_path)
|
||||||
self.model_dict = torch.load(save_path, map_location=self.device, pickle_module=PickleModuleCompatibility)
|
self.model_dict = torch.load(save_path, map_location=self.device)
|
||||||
print('Loaded!')
|
print('Loaded!')
|
||||||
print('')
|
print('')
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,13 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import defaultdict, Counter
|
from collections import defaultdict, Counter
|
||||||
from trajectron.model.components import *
|
from model.components import *
|
||||||
from trajectron.model.model_utils import *
|
from model.model_utils import *
|
||||||
from trajectron.model.dataset import get_relative_robot_traj
|
from model.dataset import get_relative_robot_traj
|
||||||
import trajectron.model.dynamics as dynamic_module
|
import model.dynamics as dynamic_module
|
||||||
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
|
from model.mgcvae import MultimodalGenerativeCVAE
|
||||||
from trajectron.environment.scene_graph import DirectedEdge
|
from environment.scene_graph import DirectedEdge
|
||||||
from trajectron.environment.node_type import NodeType
|
from environment.node_type import NodeType
|
||||||
|
|
||||||
|
|
||||||
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
class OnlineMultimodalGenerativeCVAE(MultimodalGenerativeCVAE):
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from trajectron.model.trajectron import Trajectron
|
from model.trajectron import Trajectron
|
||||||
from trajectron.model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
from model.online.online_mgcvae import OnlineMultimodalGenerativeCVAE
|
||||||
from trajectron.model.model_utils import ModeKeys
|
from model.model_utils import ModeKeys
|
||||||
from trajectron.environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
from environment import RingBuffer, TemporalSceneGraph, SceneGraph, derivative_of
|
||||||
|
|
||||||
|
|
||||||
class OnlineTrajectron(Trajectron):
|
class OnlineTrajectron(Trajectron):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from trajectron.model.mgcvae import MultimodalGenerativeCVAE
|
from model.mgcvae import MultimodalGenerativeCVAE
|
||||||
from trajectron.model.dataset import get_timesteps_data, restore
|
from model.dataset import get_timesteps_data, restore
|
||||||
|
|
||||||
|
|
||||||
class Trajectron(object):
|
class Trajectron(object):
|
||||||
|
|
|
@ -5,12 +5,13 @@ import torch
|
||||||
import dill
|
import dill
|
||||||
import random
|
import random
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import evaluation
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trajectron.visualization as vis
|
import visualization as vis
|
||||||
from trajectron.argument_parser import args
|
from argument_parser import args
|
||||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
from model.online.online_trajectron import OnlineTrajectron
|
||||||
from trajectron.model.model_registrar import ModelRegistrar
|
from model.model_registrar import ModelRegistrar
|
||||||
from trajectron.environment import Environment, Scene
|
from environment import Environment, Scene
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
if not torch.cuda.is_available() or args.device == 'cpu':
|
if not torch.cuda.is_available() or args.device == 'cpu':
|
|
@ -1,4 +1,4 @@
|
||||||
from trajectron.utils import prediction_output_to_trajectories
|
from utils import prediction_output_to_trajectories
|
||||||
from scipy import linalg
|
from scipy import linalg
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import matplotlib.patches as patches
|
import matplotlib.patches as patches
|
||||||
|
|
|
@ -1,153 +0,0 @@
|
||||||
import argparse
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--verbose',
|
|
||||||
'-v',
|
|
||||||
help="Increase verbosity. Add multiple times to increase further.",
|
|
||||||
action='count', default=0
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--remote-log-addr',
|
|
||||||
help="Connect to a remote logger like cutelog. Specify the ip",
|
|
||||||
type=str,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--remote-log-port',
|
|
||||||
help="Connect to a remote logger like cutelog. Specify the port",
|
|
||||||
type=int,
|
|
||||||
default=19996
|
|
||||||
)
|
|
||||||
|
|
||||||
# parser.add_argument('--foo')
|
|
||||||
inference_parser = parser.add_argument_group('inference server')
|
|
||||||
connection_parser = parser.add_argument_group('connection')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--model_dir",
|
|
||||||
help="directory with the model to use for inference",
|
|
||||||
type=str, # TODO: make into Path
|
|
||||||
default='./experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--conf",
|
|
||||||
help="path to json config file for hyperparameters, relative to model_dir",
|
|
||||||
type=str,
|
|
||||||
default='config.json')
|
|
||||||
|
|
||||||
# Model Parameters (hyperparameters)
|
|
||||||
inference_parser.add_argument("--offline_scene_graph",
|
|
||||||
help="whether to precompute the scene graphs offline, options are 'no' and 'yes'",
|
|
||||||
type=str,
|
|
||||||
default='yes')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--dynamic_edges",
|
|
||||||
help="whether to use dynamic edges or not, options are 'no' and 'yes'",
|
|
||||||
type=str,
|
|
||||||
default='yes')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--edge_state_combine_method",
|
|
||||||
help="the method to use for combining edges of the same type",
|
|
||||||
type=str,
|
|
||||||
default='sum')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--edge_influence_combine_method",
|
|
||||||
help="the method to use for combining edge influences",
|
|
||||||
type=str,
|
|
||||||
default='attention')
|
|
||||||
|
|
||||||
inference_parser.add_argument('--edge_addition_filter',
|
|
||||||
nargs='+',
|
|
||||||
help="what scaling to use for edges as they're created",
|
|
||||||
type=float,
|
|
||||||
default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp
|
|
||||||
# and short edge addition, then you need to have a 0.0 at the
|
|
||||||
# beginning, e.g. [0.0, 1.0].
|
|
||||||
|
|
||||||
inference_parser.add_argument('--edge_removal_filter',
|
|
||||||
nargs='+',
|
|
||||||
help="what scaling to use for edges as they're removed",
|
|
||||||
type=float,
|
|
||||||
default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like
|
|
||||||
# the default, then you need to have a 0.0 at the end.
|
|
||||||
|
|
||||||
|
|
||||||
inference_parser.add_argument('--incl_robot_node',
|
|
||||||
help="whether to include a robot node in the graph or simply model all agents",
|
|
||||||
action='store_true')
|
|
||||||
|
|
||||||
inference_parser.add_argument('--map_encoding',
|
|
||||||
help="Whether to use map encoding or not",
|
|
||||||
action='store_true')
|
|
||||||
|
|
||||||
inference_parser.add_argument('--no_edge_encoding',
|
|
||||||
help="Whether to use neighbors edge encoding",
|
|
||||||
action='store_true')
|
|
||||||
|
|
||||||
|
|
||||||
inference_parser.add_argument('--batch_size',
|
|
||||||
help='training batch size',
|
|
||||||
type=int,
|
|
||||||
default=256)
|
|
||||||
|
|
||||||
inference_parser.add_argument('--k_eval',
|
|
||||||
help='how many samples to take during evaluation',
|
|
||||||
type=int,
|
|
||||||
default=25)
|
|
||||||
|
|
||||||
# Data Parameters
|
|
||||||
inference_parser.add_argument("--eval_data_dict",
|
|
||||||
help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)",
|
|
||||||
type=str,
|
|
||||||
default='./experiments/processed/eth_test.pkl')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--output_dir",
|
|
||||||
help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)",
|
|
||||||
type=str,
|
|
||||||
default='../experiments/pedestrians/OUT/test_inference')
|
|
||||||
|
|
||||||
|
|
||||||
# inference_parser.add_argument('--device',
|
|
||||||
# help='what device to perform training on',
|
|
||||||
# type=str,
|
|
||||||
# default='cuda:0')
|
|
||||||
|
|
||||||
inference_parser.add_argument("--eval_device",
|
|
||||||
help="what device to use during inference",
|
|
||||||
type=str,
|
|
||||||
default="cpu")
|
|
||||||
|
|
||||||
|
|
||||||
inference_parser.add_argument('--seed',
|
|
||||||
help='manual seed to use, default is 123',
|
|
||||||
type=int,
|
|
||||||
default=123)
|
|
||||||
|
|
||||||
|
|
||||||
# Internal connections.
|
|
||||||
|
|
||||||
connection_parser.add_argument('--zmq-trajectory-addr',
|
|
||||||
help='Manually specity communication addr for the trajectory messages',
|
|
||||||
type=str,
|
|
||||||
default="ipc:///tmp/feeds/traj")
|
|
||||||
|
|
||||||
connection_parser.add_argument('--zmq-camera-stream-addr',
|
|
||||||
help='Manually specity communication addr for the camera stream messages',
|
|
||||||
type=str,
|
|
||||||
default="ipc:///tmp/feeds/img")
|
|
||||||
|
|
||||||
connection_parser.add_argument('--zmq-prediction-addr',
|
|
||||||
help='Manually specity communication addr for the prediction messages',
|
|
||||||
type=str,
|
|
||||||
default="ipc:///tmp/feeds/preds")
|
|
||||||
|
|
||||||
|
|
||||||
connection_parser.add_argument('--ws-port',
|
|
||||||
help='Port to listen for incomming websocket connections. Also serves the testing html-page.',
|
|
||||||
type=int,
|
|
||||||
default=8888)
|
|
||||||
|
|
||||||
connection_parser.add_argument('--bypass-prediction',
|
|
||||||
help='For debugging purpose: websocket input immediately to output',
|
|
||||||
action='store_true')
|
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
import logging
|
|
||||||
from logging.handlers import SocketHandler
|
|
||||||
from multiprocessing import Process, Queue
|
|
||||||
from trajpred.config import parser
|
|
||||||
from trajpred.prediction_server import InferenceServer, run_inference_server
|
|
||||||
from trajpred.socket_forwarder import run_ws_forwarder
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("trajpred.plumbing")
|
|
||||||
|
|
||||||
def start():
|
|
||||||
args = parser.parse_args()
|
|
||||||
loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO
|
|
||||||
logging.basicConfig(
|
|
||||||
level=loglevel,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.remote_log_addr:
|
|
||||||
logging.captureWarnings(True)
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
root_logger.setLevel(logging.NOTSET) # to send all records to cutelog
|
|
||||||
socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port)
|
|
||||||
root_logger.addHandler(socket_handler)
|
|
||||||
|
|
||||||
# instantiating process with arguments
|
|
||||||
procs = [
|
|
||||||
Process(target=run_ws_forwarder, args=(args,))
|
|
||||||
]
|
|
||||||
if not args.bypass_prediction:
|
|
||||||
procs.append(
|
|
||||||
Process(target=run_inference_server, args=(args,)),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("start")
|
|
||||||
for proc in procs:
|
|
||||||
proc.start()
|
|
||||||
|
|
||||||
for proc in procs:
|
|
||||||
proc.join()
|
|
||||||
|
|
|
@ -1,270 +0,0 @@
|
||||||
# adapted from Trajectron++ online_server.py
|
|
||||||
import logging
|
|
||||||
from multiprocessing import Queue
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
import dill
|
|
||||||
import random
|
|
||||||
import pathlib
|
|
||||||
import numpy as np
|
|
||||||
from trajectron.utils import prediction_output_to_trajectories
|
|
||||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
|
||||||
from trajectron.model.model_registrar import ModelRegistrar
|
|
||||||
from trajectron.environment import Environment, Scene
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
import zmq
|
|
||||||
|
|
||||||
logger = logging.getLogger("trajpred.inference")
|
|
||||||
|
|
||||||
|
|
||||||
# if not torch.cuda.is_available() or self.config.device == 'cpu':
|
|
||||||
# self.config.device = torch.device('cpu')
|
|
||||||
# else:
|
|
||||||
# if torch.cuda.device_count() == 1:
|
|
||||||
# # If you have CUDA_VISIBLE_DEVICES set, which you should,
|
|
||||||
# # then this will prevent leftover flag arguments from
|
|
||||||
# # messing with the device allocation.
|
|
||||||
# self.config.device = 'cuda:0'
|
|
||||||
|
|
||||||
# self.config.device = torch.device(self.config.device)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def create_online_env(env, hyperparams, scene_idx, init_timestep):
|
|
||||||
test_scene = env.scenes[scene_idx]
|
|
||||||
|
|
||||||
online_scene = Scene(timesteps=init_timestep + 1,
|
|
||||||
map=test_scene.map,
|
|
||||||
dt=test_scene.dt)
|
|
||||||
online_scene.nodes = test_scene.get_nodes_clipped_at_time(
|
|
||||||
timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'],
|
|
||||||
init_timestep + 1),
|
|
||||||
state=hyperparams['state'])
|
|
||||||
online_scene.robot = test_scene.robot
|
|
||||||
online_scene.calculate_scene_graph(attention_radius=env.attention_radius,
|
|
||||||
edge_addition_filter=hyperparams['edge_addition_filter'],
|
|
||||||
edge_removal_filter=hyperparams['edge_removal_filter'])
|
|
||||||
|
|
||||||
return Environment(node_type_list=env.node_type_list,
|
|
||||||
standardization=env.standardization,
|
|
||||||
scenes=[online_scene],
|
|
||||||
attention_radius=env.attention_radius,
|
|
||||||
robot_type=env.robot_type)
|
|
||||||
|
|
||||||
|
|
||||||
def get_maps_for_input(input_dict, scene, hyperparams):
|
|
||||||
scene_maps = list()
|
|
||||||
scene_pts = list()
|
|
||||||
heading_angles = list()
|
|
||||||
patch_sizes = list()
|
|
||||||
nodes_with_maps = list()
|
|
||||||
for node in input_dict:
|
|
||||||
if node.type in hyperparams['map_encoder']:
|
|
||||||
x = input_dict[node]
|
|
||||||
me_hyp = hyperparams['map_encoder'][node.type]
|
|
||||||
if 'heading_state_index' in me_hyp:
|
|
||||||
heading_state_index = me_hyp['heading_state_index']
|
|
||||||
# We have to rotate the map in the opposit direction of the agent to match them
|
|
||||||
if type(heading_state_index) is list: # infer from velocity or heading vector
|
|
||||||
heading_angle = -np.arctan2(x[-1, heading_state_index[1]],
|
|
||||||
x[-1, heading_state_index[0]]) * 180 / np.pi
|
|
||||||
else:
|
|
||||||
heading_angle = -x[-1, heading_state_index] * 180 / np.pi
|
|
||||||
else:
|
|
||||||
heading_angle = None
|
|
||||||
|
|
||||||
scene_map = scene.map[node.type]
|
|
||||||
map_point = x[-1, :2]
|
|
||||||
|
|
||||||
patch_size = hyperparams['map_encoder'][node.type]['patch_size']
|
|
||||||
|
|
||||||
scene_maps.append(scene_map)
|
|
||||||
scene_pts.append(map_point)
|
|
||||||
heading_angles.append(heading_angle)
|
|
||||||
patch_sizes.append(patch_size)
|
|
||||||
nodes_with_maps.append(node)
|
|
||||||
|
|
||||||
if heading_angles[0] is None:
|
|
||||||
heading_angles = None
|
|
||||||
else:
|
|
||||||
heading_angles = torch.Tensor(heading_angles)
|
|
||||||
|
|
||||||
maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps,
|
|
||||||
scene_pts=torch.Tensor(scene_pts),
|
|
||||||
patch_size=patch_sizes[0],
|
|
||||||
rotation=heading_angles)
|
|
||||||
|
|
||||||
maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)}
|
|
||||||
return maps_dict
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceServer:
|
|
||||||
def __init__(self, config: dict):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
context = zmq.Context()
|
|
||||||
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
|
||||||
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
|
||||||
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
|
||||||
|
|
||||||
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
|
||||||
self.prediction_socket.bind(config.zmq_prediction_addr)
|
|
||||||
print(self.prediction_socket)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
|
|
||||||
if self.config.seed is not None:
|
|
||||||
random.seed(self.config.seed)
|
|
||||||
np.random.seed(self.config.seed)
|
|
||||||
torch.manual_seed(self.config.seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed_all(self.config.seed)
|
|
||||||
|
|
||||||
# Choose one of the model directory names under the experiment/*/models folders.
|
|
||||||
# Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot'
|
|
||||||
# model_dir = os.path.join(self.config.log_dir, 'int_ee')
|
|
||||||
# model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3'
|
|
||||||
|
|
||||||
# Load hyperparameters from json
|
|
||||||
config_file = os.path.join(self.config.model_dir, self.config.conf)
|
|
||||||
if not os.path.exists(config_file):
|
|
||||||
raise ValueError('Config json not found!')
|
|
||||||
with open(config_file, 'r') as conf_json:
|
|
||||||
hyperparams = json.load(conf_json)
|
|
||||||
|
|
||||||
# Add hyperparams from arguments
|
|
||||||
hyperparams['dynamic_edges'] = self.config.dynamic_edges
|
|
||||||
hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method
|
|
||||||
hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method
|
|
||||||
hyperparams['edge_addition_filter'] = self.config.edge_addition_filter
|
|
||||||
hyperparams['edge_removal_filter'] = self.config.edge_removal_filter
|
|
||||||
hyperparams['batch_size'] = self.config.batch_size
|
|
||||||
hyperparams['k_eval'] = self.config.k_eval
|
|
||||||
hyperparams['offline_scene_graph'] = self.config.offline_scene_graph
|
|
||||||
hyperparams['incl_robot_node'] = self.config.incl_robot_node
|
|
||||||
hyperparams['edge_encoding'] = not self.config.no_edge_encoding
|
|
||||||
hyperparams['use_map_encoding'] = self.config.map_encoding
|
|
||||||
|
|
||||||
output_save_dir = os.path.join(self.config.output_dir, 'pred_figs')
|
|
||||||
pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
|
|
||||||
with open(self.config.eval_data_dict, 'rb') as f:
|
|
||||||
eval_env = dill.load(f, encoding='latin1')
|
|
||||||
|
|
||||||
if eval_env.robot_type is None and hyperparams['incl_robot_node']:
|
|
||||||
eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify?
|
|
||||||
for scene in eval_env.scenes:
|
|
||||||
scene.add_robot_from_nodes(eval_env.robot_type)
|
|
||||||
|
|
||||||
logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
|
|
||||||
|
|
||||||
# Creating a dummy environment with a single scene that contains information about the world.
|
|
||||||
# When using this code, feel free to use whichever scene index or initial timestep you wish.
|
|
||||||
scene_idx = 0
|
|
||||||
|
|
||||||
# You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1],
|
|
||||||
# so that you can immediately start incremental inference from the 3rd timestep onwards.
|
|
||||||
init_timestep = 1
|
|
||||||
|
|
||||||
eval_scene = eval_env.scenes[scene_idx]
|
|
||||||
online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep)
|
|
||||||
|
|
||||||
model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device)
|
|
||||||
model_registrar.load_models(iter_num=100)
|
|
||||||
|
|
||||||
trajectron = OnlineTrajectron(model_registrar,
|
|
||||||
hyperparams,
|
|
||||||
self.config.eval_device)
|
|
||||||
|
|
||||||
# If you want to see what different robot futures do to the predictions, uncomment this line as well as
|
|
||||||
# related "... += adjustment" lines below.
|
|
||||||
# adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1)
|
|
||||||
|
|
||||||
# Here's how you'd incrementally run the model, e.g. with streaming data.
|
|
||||||
trajectron.set_environment(online_env, init_timestep)
|
|
||||||
|
|
||||||
for timestep in range(init_timestep + 1, eval_scene.timesteps):
|
|
||||||
input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state'])
|
|
||||||
|
|
||||||
maps = None
|
|
||||||
if hyperparams['use_map_encoding']:
|
|
||||||
maps = get_maps_for_input(input_dict, eval_scene, hyperparams)
|
|
||||||
|
|
||||||
robot_present_and_future = None
|
|
||||||
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
|
|
||||||
robot_present_and_future = eval_scene.robot.get(np.array([timestep,
|
|
||||||
timestep + hyperparams['prediction_horizon']]),
|
|
||||||
hyperparams['state'][eval_scene.robot.type],
|
|
||||||
padding=0.0)
|
|
||||||
robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0)
|
|
||||||
# robot_present_and_future += adjustment
|
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
dists, preds = trajectron.incremental_forward(input_dict,
|
|
||||||
maps,
|
|
||||||
prediction_horizon=6,
|
|
||||||
num_samples=51,
|
|
||||||
robot_present_and_future=robot_present_and_future,
|
|
||||||
full_dist=True)
|
|
||||||
end = time.time()
|
|
||||||
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
|
||||||
1. / (end - start), len(trajectron.nodes),
|
|
||||||
trajectron.scene_graph.get_num_edges()))
|
|
||||||
|
|
||||||
# unsure what this bit from online_prediction.py does:
|
|
||||||
# detailed_preds_dict = dict()
|
|
||||||
# for node in eval_scene.nodes:
|
|
||||||
# if node in preds:
|
|
||||||
# detailed_preds_dict[node] = preds[node]
|
|
||||||
|
|
||||||
#adapted from trajectron.visualization
|
|
||||||
# prediction_dict provides the actual predictions
|
|
||||||
# histories_dict provides the trajectory used for prediction
|
|
||||||
# futures_dict is the Ground Truth, which is unvailable in an online setting
|
|
||||||
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
|
|
||||||
eval_scene.dt,
|
|
||||||
hyperparams['maximum_history_length'],
|
|
||||||
hyperparams['prediction_horizon']
|
|
||||||
)
|
|
||||||
|
|
||||||
assert(len(prediction_dict.keys()) <= 1)
|
|
||||||
if len(prediction_dict.keys()) == 0:
|
|
||||||
return
|
|
||||||
ts_key = list(prediction_dict.keys())[0]
|
|
||||||
|
|
||||||
prediction_dict = prediction_dict[ts_key]
|
|
||||||
histories_dict = histories_dict[ts_key]
|
|
||||||
futures_dict = futures_dict[ts_key]
|
|
||||||
|
|
||||||
response = {}
|
|
||||||
|
|
||||||
for node in histories_dict:
|
|
||||||
history = histories_dict[node]
|
|
||||||
# future = futures_dict[node]
|
|
||||||
predictions = prediction_dict[node]
|
|
||||||
|
|
||||||
if np.isnan(history[-1]).any():
|
|
||||||
continue
|
|
||||||
|
|
||||||
response[node.id] = {
|
|
||||||
'id': node.id,
|
|
||||||
'history': history.tolist(),
|
|
||||||
'predictions': predictions[0].tolist() # use batch 0
|
|
||||||
}
|
|
||||||
|
|
||||||
data = json.dumps(response)
|
|
||||||
self.prediction_socket.send_string(data)
|
|
||||||
# time.sleep(1)
|
|
||||||
# print(prediction_dict)
|
|
||||||
# print(histories_dict)
|
|
||||||
# print(futures_dict)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_inference_server(config):
|
|
||||||
s = InferenceServer(config)
|
|
||||||
s.run()
|
|
|
@ -1,162 +0,0 @@
|
||||||
|
|
||||||
from argparse import Namespace
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from typing import Set, Union, Dict, Any
|
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from urllib.error import HTTPError
|
|
||||||
import tornado.ioloop
|
|
||||||
import tornado.web
|
|
||||||
import tornado.websocket
|
|
||||||
import zmq
|
|
||||||
import zmq.asyncio
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("trajpred.forwarder")
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler):
|
|
||||||
def initialize(self, zmq_socket: zmq.asyncio.Socket):
|
|
||||||
self.zmq_socket = zmq_socket
|
|
||||||
|
|
||||||
async def on_message(self, message):
|
|
||||||
logger.debug(f"recieve msg")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.zmq_socket.send_string(message)
|
|
||||||
# msg = json.loads(message)
|
|
||||||
except Exception as e:
|
|
||||||
# self.send({'alert': 'Invalid request: {}'.format(e)})
|
|
||||||
logger.exception(e)
|
|
||||||
# self.write_message(u"You said: " + message)
|
|
||||||
|
|
||||||
def open(self, p=None):
|
|
||||||
logger.info(f"connected {self.request.remote_ip}")
|
|
||||||
|
|
||||||
# client disconnected
|
|
||||||
def on_close(self):
|
|
||||||
logger.info(f"Client disconnected: {self.request.remote_ip}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler):
|
|
||||||
connections: Set[Self] = set()
|
|
||||||
|
|
||||||
def initialize(self, config):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def on_message(self, message):
|
|
||||||
logger.warning(f"Receiving message on send-only ws handler: {message}")
|
|
||||||
|
|
||||||
def open(self, p=None):
|
|
||||||
logger.info(f"Prediction WS connected {self.request.remote_ip}")
|
|
||||||
self.__class__.connections.add(self)
|
|
||||||
|
|
||||||
# client disconnected
|
|
||||||
def on_close(self):
|
|
||||||
self.__class__.rmConnection(self)
|
|
||||||
|
|
||||||
logger.info(f"Client disconnected: {self.request.remote_ip}")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def rmConnection(cls, client):
|
|
||||||
if client not in cls.connections:
|
|
||||||
return
|
|
||||||
cls.connections.remove(client)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def hasConnection(cls, client):
|
|
||||||
return client in cls.connections
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]):
|
|
||||||
if msg is None:
|
|
||||||
logger.critical("Tried to send 'none'")
|
|
||||||
return
|
|
||||||
|
|
||||||
toRemove = []
|
|
||||||
for client in cls.connections:
|
|
||||||
try:
|
|
||||||
client.write_message(msg)
|
|
||||||
except tornado.websocket.WebSocketClosedError as e:
|
|
||||||
logger.warning(f"Not properly closed websocket connection")
|
|
||||||
toRemove.append(client) # If we remove it here from the set we get an exception about changing set size during iteration
|
|
||||||
|
|
||||||
for client in toRemove:
|
|
||||||
cls.rmConnection(client)
|
|
||||||
|
|
||||||
class DemoHandler(tornado.web.RequestHandler):
|
|
||||||
def initialize(self, config: Namespace):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
self.render("index.html", ws_port=self.config.ws_port)
|
|
||||||
|
|
||||||
class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler):
|
|
||||||
def set_extra_headers(self, path):
|
|
||||||
"""For subclass to add extra headers to the response"""
|
|
||||||
if path[-5:] == ".html":
|
|
||||||
self.set_header("Access-Control-Allow-Origin", "*")
|
|
||||||
if path[-4:] == ".svg":
|
|
||||||
self.set_header("Content-Type", "image/svg+xml")
|
|
||||||
|
|
||||||
|
|
||||||
class WsRouter:
|
|
||||||
def __init__(self, config: Namespace):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
context = zmq.asyncio.Context()
|
|
||||||
self.trajectory_socket = context.socket(zmq.PUB)
|
|
||||||
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
|
|
||||||
|
|
||||||
self.prediction_socket = context.socket(zmq.SUB)
|
|
||||||
self.prediction_socket.connect(config.zmq_prediction_addr)
|
|
||||||
self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
|
||||||
|
|
||||||
self.application = tornado.web.Application(
|
|
||||||
[
|
|
||||||
(
|
|
||||||
r"/ws/trajectory",
|
|
||||||
WebSocketTrajectoryHandler,
|
|
||||||
{
|
|
||||||
"zmq_socket": self.trajectory_socket
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(
|
|
||||||
r"/ws/prediction",
|
|
||||||
WebSocketPredictionHandler,
|
|
||||||
{
|
|
||||||
"config": config,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(r"/", DemoHandler, {"config": config}),
|
|
||||||
# (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}),
|
|
||||||
],
|
|
||||||
template_path = 'trajpred/web/',
|
|
||||||
compiled_template_cache=False)
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
|
|
||||||
evt_loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(evt_loop)
|
|
||||||
|
|
||||||
# loop = tornado.ioloop.IOLoop.current()
|
|
||||||
logger.info(f"Listen on {self.config.ws_port}")
|
|
||||||
self.application.listen(self.config.ws_port)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
task = evt_loop.create_task(self.prediction_forwarder())
|
|
||||||
|
|
||||||
evt_loop.run_forever()
|
|
||||||
|
|
||||||
async def prediction_forwarder(self):
|
|
||||||
logger.info("Starting prediction forwarder")
|
|
||||||
while True:
|
|
||||||
msg = await self.prediction_socket.recv_string()
|
|
||||||
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
|
||||||
WebSocketPredictionHandler.write_to_clients(msg)
|
|
||||||
|
|
||||||
def run_ws_forwarder(config: Namespace):
|
|
||||||
router = WsRouter(config)
|
|
||||||
router.start()
|
|
|
@ -1,170 +0,0 @@
|
||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
<title>Trajectory Prediction Browser Test</title>
|
|
||||||
<style>
|
|
||||||
body {
|
|
||||||
background: black;
|
|
||||||
}
|
|
||||||
|
|
||||||
#field {
|
|
||||||
background: white;
|
|
||||||
width: 100%;
|
|
||||||
height: 100%;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
|
|
||||||
<body>
|
|
||||||
<canvas id="field" width="1500" height="1500">
|
|
||||||
|
|
||||||
</canvas>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
// minified https://github.com/joewalnes/reconnecting-websocket
|
|
||||||
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<script>
|
|
||||||
// map the field to coordinates of our dummy tracker
|
|
||||||
const field_range = { x: [-10, 10], y: [-10, 10] }
|
|
||||||
|
|
||||||
// Create WebSocket connection.
|
|
||||||
const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`);
|
|
||||||
const prediction_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/prediction`);
|
|
||||||
let is_moving = false;
|
|
||||||
const fieldEl = document.getElementById('field');
|
|
||||||
|
|
||||||
let current_data = {}
|
|
||||||
// Listen for messages
|
|
||||||
prediction_socket.addEventListener("message", (event) => {
|
|
||||||
// console.log("Message from server ", event.data);
|
|
||||||
current_data = JSON.parse(event.data);
|
|
||||||
});
|
|
||||||
|
|
||||||
function getMousePos(canvas, evt) {
|
|
||||||
const rect = canvas.getBoundingClientRect();
|
|
||||||
return {
|
|
||||||
x: evt.clientX - rect.left,
|
|
||||||
y: evt.clientY - rect.top
|
|
||||||
};
|
|
||||||
}
|
|
||||||
function mouse_coordinates_to_position(coordinates) {
|
|
||||||
const x_range = field_range.x[1] - field_range.x[0]
|
|
||||||
const x = (coordinates.x / fieldEl.clientWidth) * x_range + field_range.x[0]
|
|
||||||
const y_range = field_range.y[1] - field_range.y[0]
|
|
||||||
const y = (coordinates.y / fieldEl.clientWidth) * y_range + field_range.y[0]
|
|
||||||
return { x: x, y: y }
|
|
||||||
}
|
|
||||||
function position_to_canvas_coordinate(position) {
|
|
||||||
const x_range = field_range.x[1] - field_range.x[0]
|
|
||||||
const y_range = field_range.y[1] - field_range.y[0]
|
|
||||||
|
|
||||||
const x = Array.isArray(position) ? position[0] : position.x;
|
|
||||||
const y = Array.isArray(position) ? position[1] : position.y;
|
|
||||||
return {
|
|
||||||
x: (x - field_range.x[0]) * fieldEl.width / x_range,
|
|
||||||
y: (y - field_range.y[0]) * fieldEl.width / y_range,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// helper function so we can spread
|
|
||||||
function coord_as_list(coord) {
|
|
||||||
return [coord.x, coord.y]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
let tracker = {}
|
|
||||||
let person_counter = 0
|
|
||||||
|
|
||||||
class Person {
|
|
||||||
constructor(id) {
|
|
||||||
this.id = id;
|
|
||||||
this.history = [];
|
|
||||||
this.prediction = []
|
|
||||||
}
|
|
||||||
|
|
||||||
addToHistory(position) {
|
|
||||||
this.history.push(position);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldEl.addEventListener('mousedown', (event) => {
|
|
||||||
person_counter++;
|
|
||||||
tracker[person_counter] = new Person(person_counter);
|
|
||||||
is_moving = true;
|
|
||||||
|
|
||||||
const mousePos = getMousePos(fieldEl, event);
|
|
||||||
const position = mouse_coordinates_to_position(mousePos)
|
|
||||||
tracker[person_counter].addToHistory(position);
|
|
||||||
trajectory_socket.send(JSON.stringify(tracker))
|
|
||||||
});
|
|
||||||
fieldEl.addEventListener('mousemove', (event) => {
|
|
||||||
if (!is_moving) return;
|
|
||||||
const mousePos = getMousePos(fieldEl, event);
|
|
||||||
const position = mouse_coordinates_to_position(mousePos)
|
|
||||||
tracker[person_counter].addToHistory(position);
|
|
||||||
trajectory_socket.send(JSON.stringify(tracker))
|
|
||||||
});
|
|
||||||
document.addEventListener('mouseup', (e) => {
|
|
||||||
is_moving = false;
|
|
||||||
tracker = {}
|
|
||||||
})
|
|
||||||
|
|
||||||
const ctx = fieldEl.getContext("2d");
|
|
||||||
function drawFrame() {
|
|
||||||
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
|
|
||||||
ctx.save();
|
|
||||||
|
|
||||||
for (let id in current_data) {
|
|
||||||
const person = current_data[id];
|
|
||||||
if (person.history.length > 1) {
|
|
||||||
const hist = structuredClone(person.history)
|
|
||||||
// draw current position:
|
|
||||||
ctx.beginPath()
|
|
||||||
ctx.arc(
|
|
||||||
...coord_as_list(position_to_canvas_coordinate(hist[hist.length - 1])),
|
|
||||||
5, //radius
|
|
||||||
0, 2 * Math.PI);
|
|
||||||
ctx.fill()
|
|
||||||
|
|
||||||
ctx.beginPath()
|
|
||||||
ctx.lineWidth = 3;
|
|
||||||
ctx.strokeStyle = "#325FA2";
|
|
||||||
|
|
||||||
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(hist.shift())));
|
|
||||||
for (const position of hist) {
|
|
||||||
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
|
|
||||||
}
|
|
||||||
ctx.stroke();
|
|
||||||
}
|
|
||||||
|
|
||||||
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
|
|
||||||
// multiple predictions can be sampled
|
|
||||||
person.predictions.forEach((prediction, i) => {
|
|
||||||
ctx.beginPath()
|
|
||||||
ctx.lineWidth = i === 1 ? 3 : 0.2;
|
|
||||||
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
|
|
||||||
|
|
||||||
// start from current position:
|
|
||||||
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
|
|
||||||
for (const position of prediction) {
|
|
||||||
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
|
|
||||||
}
|
|
||||||
ctx.stroke();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx.restore();
|
|
||||||
|
|
||||||
window.requestAnimationFrame(drawFrame);
|
|
||||||
}
|
|
||||||
|
|
||||||
window.requestAnimationFrame(drawFrame);
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
|
|
||||||
</html>
|
|
Loading…
Reference in a new issue