Trajectory prediction - test in browser
This commit is contained in:
		
						commit
						7c06913d88
					
				
					 12 changed files with 4305 additions and 0 deletions
				
			
		
							
								
								
									
										1
									
								
								.python-version
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								.python-version
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | |||
| 3.10.4 | ||||
							
								
								
									
										3229
									
								
								poetry.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
							
						
						
									
										3229
									
								
								poetry.lock
									
									
									
										generated
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							
							
								
								
									
										15
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | |||
| [tool.poetry] | ||||
| name = "trap" | ||||
| version = "0.1.0" | ||||
| description = "Art installation with trajectory prediction" | ||||
| authors = ["Ruben van de Ven <git@rubenvandeven.com>"] | ||||
| readme = "README.md" | ||||
| 
 | ||||
| [tool.poetry.dependencies] | ||||
| python = "^3.10,<3.12," | ||||
| 
 | ||||
| trajectron-plus-plus = { path = "../Trajectron-plus-plus/", develop = true } | ||||
| 
 | ||||
| [build-system] | ||||
| requires = ["poetry-core"] | ||||
| build-backend = "poetry.core.masonry.api" | ||||
							
								
								
									
										5
									
								
								run_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								run_server.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,5 @@ | |||
| from trap import plumber | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     plumber.start() | ||||
| 
 | ||||
							
								
								
									
										0
									
								
								trap/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								trap/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										168
									
								
								trap/config.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								trap/config.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,168 @@ | |||
| import argparse | ||||
| from pathlib import Path | ||||
| 
 | ||||
| parser = argparse.ArgumentParser() | ||||
| 
 | ||||
| 
 | ||||
| parser.add_argument( | ||||
|     '--verbose', | ||||
|     '-v', | ||||
|     help="Increase verbosity. Add multiple times to increase further.", | ||||
|     action='count', default=0 | ||||
| ) | ||||
| parser.add_argument( | ||||
|     '--remote-log-addr', | ||||
|     help="Connect to a remote logger like cutelog. Specify the ip", | ||||
|     type=str, | ||||
| ) | ||||
| parser.add_argument( | ||||
|     '--remote-log-port', | ||||
|     help="Connect to a remote logger like cutelog. Specify the port", | ||||
|     type=int, | ||||
|     default=19996 | ||||
| ) | ||||
| 
 | ||||
| # parser.add_argument('--foo') | ||||
| inference_parser = parser.add_argument_group('inference server') | ||||
| connection_parser = parser.add_argument_group('connection') | ||||
| frame_emitter_parser = parser.add_argument_group('Frame emitter') | ||||
| 
 | ||||
| inference_parser.add_argument("--model_dir", | ||||
|                     help="directory with the model to use for inference", | ||||
|                     type=str, # TODO: make into Path | ||||
|                     default='../Trajectron-plus-plus/experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3') | ||||
| 
 | ||||
| inference_parser.add_argument("--conf", | ||||
|                     help="path to json config file for hyperparameters, relative to model_dir", | ||||
|                     type=str, | ||||
|                     default='config.json') | ||||
| 
 | ||||
| # Model Parameters (hyperparameters) | ||||
| inference_parser.add_argument("--offline_scene_graph", | ||||
|                     help="whether to precompute the scene graphs offline, options are 'no' and 'yes'", | ||||
|                     type=str, | ||||
|                     default='yes') | ||||
| 
 | ||||
| inference_parser.add_argument("--dynamic_edges", | ||||
|                     help="whether to use dynamic edges or not, options are 'no' and 'yes'", | ||||
|                     type=str, | ||||
|                     default='yes') | ||||
| 
 | ||||
| inference_parser.add_argument("--edge_state_combine_method", | ||||
|                     help="the method to use for combining edges of the same type", | ||||
|                     type=str, | ||||
|                     default='sum') | ||||
| 
 | ||||
| inference_parser.add_argument("--edge_influence_combine_method", | ||||
|                     help="the method to use for combining edge influences", | ||||
|                     type=str, | ||||
|                     default='attention') | ||||
| 
 | ||||
| inference_parser.add_argument('--edge_addition_filter', | ||||
|                     nargs='+', | ||||
|                     help="what scaling to use for edges as they're created", | ||||
|                     type=float, | ||||
|                     default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp | ||||
|                                                     # and short edge addition, then you need to have a 0.0 at the | ||||
|                                                     # beginning, e.g. [0.0, 1.0]. | ||||
| 
 | ||||
| inference_parser.add_argument('--edge_removal_filter', | ||||
|                     nargs='+', | ||||
|                     help="what scaling to use for edges as they're removed", | ||||
|                     type=float, | ||||
|                     default=[1.0, 0.0])  # We don't automatically pad right with 0.0, if you want a sharp drop off like | ||||
|                                          # the default, then you need to have a 0.0 at the end. | ||||
| 
 | ||||
| 
 | ||||
| inference_parser.add_argument('--incl_robot_node', | ||||
|                     help="whether to include a robot node in the graph or simply model all agents", | ||||
|                     action='store_true') | ||||
| 
 | ||||
| inference_parser.add_argument('--map_encoding', | ||||
|                     help="Whether to use map encoding or not", | ||||
|                     action='store_true') | ||||
| 
 | ||||
| inference_parser.add_argument('--no_edge_encoding', | ||||
|                     help="Whether to use neighbors edge encoding", | ||||
|                     action='store_true') | ||||
| 
 | ||||
| 
 | ||||
| inference_parser.add_argument('--batch_size', | ||||
|                     help='training batch size', | ||||
|                     type=int, | ||||
|                     default=256) | ||||
| 
 | ||||
| inference_parser.add_argument('--k_eval', | ||||
|                     help='how many samples to take during evaluation', | ||||
|                     type=int, | ||||
|                     default=25) | ||||
| 
 | ||||
| # Data Parameters | ||||
| inference_parser.add_argument("--eval_data_dict", | ||||
|                     help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)", | ||||
|                     type=str, | ||||
|                     default='../Trajectron-plus-plus/experiments/processed/eth_test.pkl') | ||||
| 
 | ||||
| inference_parser.add_argument("--output_dir", | ||||
|                     help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)", | ||||
|                     type=str, | ||||
|                     default='./OUT/test_inference') | ||||
| 
 | ||||
| 
 | ||||
| # inference_parser.add_argument('--device', | ||||
| #                     help='what device to perform training on', | ||||
| #                     type=str, | ||||
| #                     default='cuda:0') | ||||
| 
 | ||||
| inference_parser.add_argument("--eval_device", | ||||
|                     help="what device to use during inference", | ||||
|                     type=str, | ||||
|                     default="cpu") | ||||
| 
 | ||||
| 
 | ||||
| inference_parser.add_argument('--seed', | ||||
|                     help='manual seed to use, default is 123', | ||||
|                     type=int, | ||||
|                     default=123) | ||||
| 
 | ||||
| 
 | ||||
| # Internal connections. | ||||
| 
 | ||||
| connection_parser.add_argument('--zmq-trajectory-addr', | ||||
|                     help='Manually specity communication addr for the trajectory messages', | ||||
|                     type=str, | ||||
|                     default="ipc:///tmp/feeds/traj") | ||||
| 
 | ||||
| connection_parser.add_argument('--zmq-camera-stream-addr', | ||||
|                     help='Manually specity communication addr for the camera stream messages', | ||||
|                     type=str, | ||||
|                     default="ipc:///tmp/feeds/img") | ||||
| 
 | ||||
| connection_parser.add_argument('--zmq-prediction-addr', | ||||
|                     help='Manually specity communication addr for the prediction messages', | ||||
|                     type=str, | ||||
|                     default="ipc:///tmp/feeds/preds") | ||||
| 
 | ||||
| connection_parser.add_argument('--zmq-frame-addr', | ||||
|                     help='Manually specity communication addr for the frame messages', | ||||
|                     type=str, | ||||
|                     default="ipc:///tmp/feeds/frame") | ||||
| 
 | ||||
| 
 | ||||
| connection_parser.add_argument('--ws-port', | ||||
|                     help='Port to listen for incomming websocket connections. Also serves the testing html-page.', | ||||
|                     type=int, | ||||
|                     default=8888) | ||||
| 
 | ||||
| connection_parser.add_argument('--bypass-prediction', | ||||
|                     help='For debugging purpose: websocket input immediately to output', | ||||
|                     action='store_true') | ||||
| 
 | ||||
| # Frame emitter | ||||
| 
 | ||||
| frame_emitter_parser.add_argument("--video-src", | ||||
|                     help="source video to track from", | ||||
|                     type=Path, | ||||
|                     default='../DATASETS/VIRAT_subset_0102x/VIRAT_S_010200_00_000060_000218.mp4') | ||||
| 
 | ||||
| #TODO: camera | ||||
							
								
								
									
										51
									
								
								trap/frame_emitter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								trap/frame_emitter.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,51 @@ | |||
| from argparse import Namespace | ||||
| import time | ||||
| 
 | ||||
| import cv2 | ||||
| import zmq | ||||
| 
 | ||||
| 
 | ||||
| class FrameEmitter: | ||||
|     ''' | ||||
|     Emit frame in a separate threat so they can be throttled, | ||||
|     or thrown away when the rest of the system cannot keep up | ||||
|     ''' | ||||
|     def __init__(self, config: Namespace) -> None: | ||||
|         self.config = config | ||||
| 
 | ||||
|         context = zmq.Context() | ||||
|         self.frame_sock = context.socket(zmq.PUB) | ||||
|         self.frame_sock.bind(config.zmq_frame_addr) | ||||
|         self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame | ||||
| 
 | ||||
|     def emit_video(self): | ||||
|         video = cv2.VideoCapture(str(self.config.video_src)) | ||||
|         fps = video.get(cv2.CAP_PROP_FPS) | ||||
|         frame_duration = 1./fps | ||||
| 
 | ||||
|         prev_time = time.time() | ||||
|         while True: | ||||
|             ret, frame = video.read() | ||||
| 
 | ||||
|             # seek to 0 if video has finished. Infinite loop | ||||
|             if not ret: | ||||
|                 video.set(cv2.CAP_PROP_POS_FRAMES, 0) | ||||
|                 ret, frame = video.read() | ||||
|                 assert ret is not False # not really error proof... | ||||
|              | ||||
|             self.frame_sock.send(frame) | ||||
| 
 | ||||
|             # defer next loop | ||||
|             new_frame_time = time.time() | ||||
|             time_diff = (new_frame_time - prev_time) | ||||
|             if time_diff < frame_duration: | ||||
|                 time.sleep(frame_duration - time_diff) | ||||
|                 new_frame_time += frame_duration - time_diff | ||||
|             else: | ||||
|                 prev_time = new_frame_time | ||||
| 
 | ||||
|              | ||||
| 
 | ||||
| def run_frame_emitter(config: Namespace): | ||||
|     router = FrameEmitter(config) | ||||
|     router.emit_video() | ||||
							
								
								
									
										42
									
								
								trap/plumber.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								trap/plumber.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | |||
| import logging | ||||
| from logging.handlers import SocketHandler | ||||
| from multiprocessing import Process, Queue | ||||
| from trap.config import parser | ||||
| from trap.frame_emitter import run_frame_emitter | ||||
| from trap.prediction_server import InferenceServer, run_inference_server | ||||
| from trap.socket_forwarder import run_ws_forwarder | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger("trap.plumbing") | ||||
| 
 | ||||
| def start(): | ||||
|     args = parser.parse_args() | ||||
|     loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0  else logging.INFO | ||||
|     logging.basicConfig( | ||||
|         level=loglevel, | ||||
|     ) | ||||
| 
 | ||||
|     if args.remote_log_addr: | ||||
|         logging.captureWarnings(True) | ||||
|         root_logger = logging.getLogger() | ||||
|         root_logger.setLevel(logging.NOTSET)  # to send all records to cutelog | ||||
|         socket_handler = SocketHandler(args.remote_log_addr, args.remote_log_port) | ||||
|         root_logger.addHandler(socket_handler) | ||||
| 
 | ||||
|     # instantiating process with arguments | ||||
|     procs = [ | ||||
|         # Process(target=run_ws_forwarder, args=(args,)), | ||||
|         Process(target=run_frame_emitter, args=(args,)), | ||||
|     ] | ||||
|     if not args.bypass_prediction: | ||||
|         procs.append( | ||||
|             Process(target=run_inference_server, args=(args,)), | ||||
|         ) | ||||
| 
 | ||||
|     logger.info("start") | ||||
|     for proc in procs: | ||||
|         proc.start() | ||||
| 
 | ||||
|     for proc in procs: | ||||
|         proc.join() | ||||
| 
 | ||||
							
								
								
									
										352
									
								
								trap/prediction_server.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										352
									
								
								trap/prediction_server.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,352 @@ | |||
| # adapted from Trajectron++ online_server.py | ||||
| import logging | ||||
| from multiprocessing import Queue | ||||
| import os | ||||
| import time | ||||
| import json | ||||
| import pandas as pd | ||||
| import torch | ||||
| import dill | ||||
| import random | ||||
| import pathlib | ||||
| import numpy as np | ||||
| from trajectron.environment.data_utils import derivative_of | ||||
| from trajectron.utils import prediction_output_to_trajectories | ||||
| from trajectron.model.online.online_trajectron import OnlineTrajectron | ||||
| from trajectron.model.model_registrar import ModelRegistrar | ||||
| from trajectron.environment import Environment, Scene | ||||
| from trajectron.environment.node import Node | ||||
| from trajectron.environment.node_type import NodeType | ||||
| import matplotlib.pyplot as plt | ||||
| 
 | ||||
| import zmq | ||||
| 
 | ||||
| logger = logging.getLogger("trap.inference") | ||||
| 
 | ||||
| 
 | ||||
| # if not torch.cuda.is_available() or self.config.device == 'cpu': | ||||
| #     self.config.device = torch.device('cpu') | ||||
| # else: | ||||
| #     if torch.cuda.device_count() == 1: | ||||
| #         # If you have CUDA_VISIBLE_DEVICES set, which you should, | ||||
| #         # then this will prevent leftover flag arguments from | ||||
| #         # messing with the device allocation. | ||||
| #         self.config.device = 'cuda:0' | ||||
| 
 | ||||
| #     self.config.device = torch.device(self.config.device) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def create_online_env(env, hyperparams, scene_idx, init_timestep): | ||||
|     test_scene = env.scenes[scene_idx] | ||||
| 
 | ||||
|     online_scene = Scene(timesteps=init_timestep + 1, | ||||
|                          map=test_scene.map, | ||||
|                          dt=test_scene.dt) | ||||
|     online_scene.nodes = test_scene.get_nodes_clipped_at_time( | ||||
|         timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'], | ||||
|                             init_timestep + 1), | ||||
|         state=hyperparams['state']) | ||||
|     online_scene.robot = test_scene.robot | ||||
|     online_scene.calculate_scene_graph(attention_radius=env.attention_radius, | ||||
|                                        edge_addition_filter=hyperparams['edge_addition_filter'], | ||||
|                                        edge_removal_filter=hyperparams['edge_removal_filter']) | ||||
| 
 | ||||
|     return Environment(node_type_list=env.node_type_list, | ||||
|                        standardization=env.standardization, | ||||
|                        scenes=[online_scene], | ||||
|                        attention_radius=env.attention_radius, | ||||
|                        robot_type=env.robot_type) | ||||
| 
 | ||||
| 
 | ||||
| def get_maps_for_input(input_dict, scene, hyperparams): | ||||
|     scene_maps = list() | ||||
|     scene_pts = list() | ||||
|     heading_angles = list() | ||||
|     patch_sizes = list() | ||||
|     nodes_with_maps = list() | ||||
|     for node in input_dict: | ||||
|         if node.type in hyperparams['map_encoder']: | ||||
|             x = input_dict[node] | ||||
|             me_hyp = hyperparams['map_encoder'][node.type] | ||||
|             if 'heading_state_index' in me_hyp: | ||||
|                 heading_state_index = me_hyp['heading_state_index'] | ||||
|                 # We have to rotate the map in the opposit direction of the agent to match them | ||||
|                 if type(heading_state_index) is list:  # infer from velocity or heading vector | ||||
|                     heading_angle = -np.arctan2(x[-1, heading_state_index[1]], | ||||
|                                                 x[-1, heading_state_index[0]]) * 180 / np.pi | ||||
|                 else: | ||||
|                     heading_angle = -x[-1, heading_state_index] * 180 / np.pi | ||||
|             else: | ||||
|                 heading_angle = None | ||||
| 
 | ||||
|             scene_map = scene.map[node.type] | ||||
|             map_point = x[-1, :2] | ||||
| 
 | ||||
|             patch_size = hyperparams['map_encoder'][node.type]['patch_size'] | ||||
| 
 | ||||
|             scene_maps.append(scene_map) | ||||
|             scene_pts.append(map_point) | ||||
|             heading_angles.append(heading_angle) | ||||
|             patch_sizes.append(patch_size) | ||||
|             nodes_with_maps.append(node) | ||||
| 
 | ||||
|     if heading_angles[0] is None: | ||||
|         heading_angles = None | ||||
|     else: | ||||
|         heading_angles = torch.Tensor(heading_angles) | ||||
| 
 | ||||
|     maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps, | ||||
|                                                                scene_pts=torch.Tensor(scene_pts), | ||||
|                                                                patch_size=patch_sizes[0], | ||||
|                                                                rotation=heading_angles) | ||||
| 
 | ||||
|     maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)} | ||||
|     return maps_dict | ||||
| 
 | ||||
| 
 | ||||
| class InferenceServer: | ||||
|     def __init__(self, config: dict): | ||||
|         self.config = config | ||||
|          | ||||
|         context = zmq.Context() | ||||
|         self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB) | ||||
|         self.trajectory_socket.connect(config.zmq_trajectory_addr) | ||||
|         self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'') | ||||
|         self.trajectory_socket.setsockopt(zmq.CONFLATE, 1) # only keep last msg | ||||
|          | ||||
|         self.prediction_socket: zmq.Socket = context.socket(zmq.PUB) | ||||
|         self.prediction_socket.bind(config.zmq_prediction_addr) | ||||
|         print(self.prediction_socket) | ||||
| 
 | ||||
|     def run(self): | ||||
| 
 | ||||
|         if self.config.seed is not None: | ||||
|             random.seed(self.config.seed) | ||||
|             np.random.seed(self.config.seed) | ||||
|             torch.manual_seed(self.config.seed) | ||||
|             if torch.cuda.is_available(): | ||||
|                 torch.cuda.manual_seed_all(self.config.seed) | ||||
|          | ||||
|         # Choose one of the model directory names under the experiment/*/models folders. | ||||
|         # Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot' | ||||
|         # model_dir = os.path.join(self.config.log_dir, 'int_ee') | ||||
|         # model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3' | ||||
| 
 | ||||
|         # Load hyperparameters from json | ||||
|         config_file = os.path.join(self.config.model_dir, self.config.conf) | ||||
|         if not os.path.exists(config_file): | ||||
|             raise ValueError('Config json not found!') | ||||
|         with open(config_file, 'r') as conf_json: | ||||
|             hyperparams = json.load(conf_json) | ||||
| 
 | ||||
|         # Add hyperparams from arguments | ||||
|         hyperparams['dynamic_edges'] = self.config.dynamic_edges | ||||
|         hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method | ||||
|         hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method | ||||
|         hyperparams['edge_addition_filter'] = self.config.edge_addition_filter | ||||
|         hyperparams['edge_removal_filter'] = self.config.edge_removal_filter | ||||
|         hyperparams['batch_size'] = self.config.batch_size | ||||
|         hyperparams['k_eval'] = self.config.k_eval | ||||
|         hyperparams['offline_scene_graph'] = self.config.offline_scene_graph | ||||
|         hyperparams['incl_robot_node'] = self.config.incl_robot_node | ||||
|         hyperparams['edge_encoding'] = not self.config.no_edge_encoding | ||||
|         hyperparams['use_map_encoding'] = self.config.map_encoding | ||||
|         # hyperparams['maximum_history_length'] = 12 # test | ||||
| 
 | ||||
|         logger.info(f"Use hyperparams: {hyperparams=}") | ||||
| 
 | ||||
|         output_save_dir = os.path.join(self.config.output_dir, 'pred_figs') | ||||
|         pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True) | ||||
| 
 | ||||
| 
 | ||||
|         with open(self.config.eval_data_dict, 'rb') as f: | ||||
|             eval_env = dill.load(f, encoding='latin1') | ||||
| 
 | ||||
|         if eval_env.robot_type is None and hyperparams['incl_robot_node']: | ||||
|             eval_env.robot_type = eval_env.NodeType[0]  # TODO: Make more general, allow the user to specify? | ||||
|             for scene in eval_env.scenes: | ||||
|                 scene.add_robot_from_nodes(eval_env.robot_type) | ||||
| 
 | ||||
|         logger.info('Loaded data from %s' % (self.config.eval_data_dict,)) | ||||
| 
 | ||||
|         # Creating a dummy environment with a single scene that contains information about the world. | ||||
|         # When using this code, feel free to use whichever scene index or initial timestep you wish. | ||||
|         scene_idx = 0 | ||||
| 
 | ||||
|         # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1], | ||||
|         # so that you can immediately start incremental inference from the 3rd timestep onwards. | ||||
|         init_timestep = 1 | ||||
| 
 | ||||
|         eval_scene = eval_env.scenes[scene_idx] | ||||
|         online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep) | ||||
| 
 | ||||
|         model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device) | ||||
|         model_registrar.load_models(iter_num=100) | ||||
| 
 | ||||
|         trajectron = OnlineTrajectron(model_registrar, | ||||
|                                     hyperparams, | ||||
|                                     self.config.eval_device) | ||||
| 
 | ||||
|         # If you want to see what different robot futures do to the predictions, uncomment this line as well as | ||||
|         # related "... += adjustment" lines below. | ||||
|         # adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1) | ||||
| 
 | ||||
|         # Here's how you'd incrementally run the model, e.g. with streaming data. | ||||
|         trajectron.set_environment(online_env, init_timestep) | ||||
| 
 | ||||
|         timestep = init_timestep + 1 | ||||
|         while True: | ||||
|             timestep += 1 | ||||
|             # for timestep in range(init_timestep + 1, eval_scene.timesteps): | ||||
| 
 | ||||
|             # input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state']) | ||||
|             # TODO: see process_data.py on how to create a node, the provide nodes + incoming data columns | ||||
|             # data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']]) | ||||
|             # x = node_values[:, 0] | ||||
|             # y = node_values[:, 1] | ||||
|             # vx = derivative_of(x, scene.dt) | ||||
|             # vy = derivative_of(y, scene.dt) | ||||
|             # ax = derivative_of(vx, scene.dt) | ||||
|             # ay = derivative_of(vy, scene.dt) | ||||
| 
 | ||||
|             # data_dict = {('position', 'x'): x, | ||||
|             #              ('position', 'y'): y, | ||||
|             #              ('velocity', 'x'): vx, | ||||
|             #              ('velocity', 'y'): vy, | ||||
|             #              ('acceleration', 'x'): ax, | ||||
|             #              ('acceleration', 'y'): ay} | ||||
| 
 | ||||
|             # node_data = pd.DataFrame(data_dict, columns=data_columns) | ||||
|             # node = Node(node_type=env.NodeType.PEDESTRIAN, node_id=node_id, data=node_data) | ||||
|              | ||||
|             data = self.trajectory_socket.recv_string() | ||||
|             trajectory_data = json.loads(data) | ||||
|             logger.info(f"Receive {trajectory_data}") | ||||
| 
 | ||||
|             # class FakeNode: | ||||
|             #     def __init__(self, node_type: NodeType): | ||||
|             #         self.type = node_type | ||||
| 
 | ||||
|             input_dict = {} | ||||
|             for identifier, trajectory in trajectory_data.items(): | ||||
|                 # if len(trajectory['history']) < 7: | ||||
|                 #     # TODO: these trajectories should still be in the output, but without predictions | ||||
|                 #     continue | ||||
| 
 | ||||
|                 # TODO: modify this into a mapping function between JS data an the expected Node format | ||||
|                 # node = FakeNode(online_env.NodeType.PEDESTRIAN) | ||||
|                 history = [[h['x'], h['y']] for h in trajectory['history']] | ||||
|                 history = np.array(history) | ||||
|                 x = history[:, 0] | ||||
|                 y = history[:, 1] | ||||
|                 # TODO: calculate dt based on input | ||||
|                 vx = derivative_of(x, 0.2) #eval_scene.dt | ||||
|                 vy = derivative_of(y, 0.2) | ||||
|                 ax = derivative_of(vx, 0.2) | ||||
|                 ay = derivative_of(vy, 0.2) | ||||
| 
 | ||||
|                 data_dict = {('position', 'x'): x[:], | ||||
|                             ('position', 'y'): y[:], | ||||
|                             ('velocity', 'x'): vx[:], | ||||
|                             ('velocity', 'y'): vy[:], | ||||
|                             ('acceleration', 'x'): ax[:], | ||||
|                             ('acceleration', 'y'): ay[:]} | ||||
|                 data_columns = pd.MultiIndex.from_product([['position', 'velocity', 'acceleration'], ['x', 'y']]) | ||||
| 
 | ||||
|                 node_data = pd.DataFrame(data_dict, columns=data_columns) | ||||
|                 node = Node( | ||||
|                     node_type=online_env.NodeType.PEDESTRIAN, | ||||
|                     node_id=identifier, | ||||
|                     data=node_data, | ||||
|                     first_timestep=timestep | ||||
|                     ) | ||||
| 
 | ||||
|                 input_dict[node] = np.array([x[-1],y[-1],vx[-1],vy[-1],ax[-1],ay[-1]]) | ||||
| 
 | ||||
|             # print(input_dict) | ||||
| 
 | ||||
|             if not len(input_dict): | ||||
|                 # skip if our input is empty | ||||
|                 # TODO: we want to send out empty result... | ||||
| 
 | ||||
|                 data = json.dumps({}) | ||||
|                 self.prediction_socket.send_string(data) | ||||
|                 continue | ||||
| 
 | ||||
|             maps = None | ||||
|             if hyperparams['use_map_encoding']: | ||||
|                 maps = get_maps_for_input(input_dict, eval_scene, hyperparams) | ||||
|             # print(maps) | ||||
| 
 | ||||
|             robot_present_and_future = None | ||||
|             if eval_scene.robot is not None and hyperparams['incl_robot_node']: | ||||
|                 robot_present_and_future = eval_scene.robot.get(np.array([timestep, | ||||
|                                                                         timestep + hyperparams['prediction_horizon']]), | ||||
|                                                                 hyperparams['state'][eval_scene.robot.type], | ||||
|                                                                 padding=0.0) | ||||
|                 robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0) | ||||
|                 # robot_present_and_future += adjustment | ||||
| 
 | ||||
|             start = time.time() | ||||
|             dists, preds = trajectron.incremental_forward(input_dict, | ||||
|                                                         maps, | ||||
|                                                         prediction_horizon=16, # TODO: make variable | ||||
|                                                         num_samples=3, # TODO: make variable | ||||
|                                                         robot_present_and_future=robot_present_and_future, | ||||
|                                                         full_dist=True) | ||||
|             end = time.time() | ||||
|             logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, | ||||
|                                                                             1. / (end - start), len(trajectron.nodes), | ||||
|                                                                             trajectron.scene_graph.get_num_edges())) | ||||
| 
 | ||||
|             # unsure what this bit from online_prediction.py does: | ||||
|             # detailed_preds_dict = dict() | ||||
|             # for node in eval_scene.nodes: | ||||
|             #     if node in preds: | ||||
|             #         detailed_preds_dict[node] = preds[node] | ||||
| 
 | ||||
|             #adapted from trajectron.visualization | ||||
|             # prediction_dict provides the actual predictions | ||||
|             # histories_dict provides the trajectory used for prediction | ||||
|             # futures_dict is the Ground Truth, which is unvailable in an online setting | ||||
|             prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds}, | ||||
|                                                                                       eval_scene.dt, | ||||
|                                                                                       hyperparams['maximum_history_length'], | ||||
|                                                                                         hyperparams['prediction_horizon'] | ||||
|                                                                                         ) | ||||
| 
 | ||||
|             assert(len(prediction_dict.keys()) <= 1) | ||||
|             if len(prediction_dict.keys()) == 0: | ||||
|                 return | ||||
|             ts_key = list(prediction_dict.keys())[0] | ||||
| 
 | ||||
|             prediction_dict = prediction_dict[ts_key] | ||||
|             histories_dict = histories_dict[ts_key] | ||||
|             futures_dict = futures_dict[ts_key] | ||||
| 
 | ||||
|             response = {} | ||||
|             print(histories_dict) | ||||
|              | ||||
|             for node in histories_dict: | ||||
|                 history = histories_dict[node] | ||||
|                 # future = futures_dict[node] | ||||
|                 predictions = prediction_dict[node] | ||||
| 
 | ||||
|                 if not len(history) or np.isnan(history[-1]).any(): | ||||
|                     continue | ||||
| 
 | ||||
|                 response[node.id] = { | ||||
|                     'id': node.id, | ||||
|                     'history': history.tolist(), | ||||
|                     'predictions': predictions[0].tolist() # use batch 0 | ||||
|                 } | ||||
| 
 | ||||
|             data = json.dumps(response) | ||||
|             self.prediction_socket.send_string(data) | ||||
| 
 | ||||
|              | ||||
| 
 | ||||
| def run_inference_server(config): | ||||
|     s = InferenceServer(config) | ||||
|     s.run() | ||||
							
								
								
									
										162
									
								
								trap/socket_forwarder.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								trap/socket_forwarder.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,162 @@ | |||
| 
 | ||||
| from argparse import Namespace | ||||
| import asyncio | ||||
| import logging | ||||
| from typing import Set, Union, Dict, Any | ||||
| from typing_extensions import Self | ||||
| 
 | ||||
| from urllib.error import HTTPError | ||||
| import tornado.ioloop | ||||
| import tornado.web | ||||
| import tornado.websocket | ||||
| import zmq | ||||
| import zmq.asyncio | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger("trap.forwarder") | ||||
| 
 | ||||
| 
 | ||||
| class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler): | ||||
|     def initialize(self, zmq_socket: zmq.asyncio.Socket): | ||||
|         self.zmq_socket = zmq_socket | ||||
| 
 | ||||
|     async def on_message(self, message): | ||||
|         logger.debug(f"recieve msg") | ||||
| 
 | ||||
|         try: | ||||
|             await self.zmq_socket.send_string(message) | ||||
|             # msg = json.loads(message) | ||||
|         except Exception as e: | ||||
|             # self.send({'alert': 'Invalid request: {}'.format(e)}) | ||||
|             logger.exception(e) | ||||
|         # self.write_message(u"You said: " + message) | ||||
| 
 | ||||
|     def open(self, p=None): | ||||
|         logger.info(f"connected  {self.request.remote_ip}") | ||||
| 
 | ||||
|     # client disconnected | ||||
|     def on_close(self): | ||||
|         logger.info(f"Client disconnected: {self.request.remote_ip}") | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler): | ||||
|     connections: Set[Self] = set() | ||||
| 
 | ||||
|     def initialize(self, config): | ||||
|         self.config = config | ||||
| 
 | ||||
|     def on_message(self, message): | ||||
|         logger.warning(f"Receiving message on send-only ws handler: {message}") | ||||
| 
 | ||||
|     def open(self, p=None): | ||||
|         logger.info(f"Prediction WS connected {self.request.remote_ip}") | ||||
|         self.__class__.connections.add(self) | ||||
| 
 | ||||
|     # client disconnected | ||||
|     def on_close(self): | ||||
|         self.__class__.rmConnection(self) | ||||
| 
 | ||||
|         logger.info(f"Client disconnected: {self.request.remote_ip}") | ||||
| 
 | ||||
|     @classmethod | ||||
|     def rmConnection(cls, client): | ||||
|         if client not in cls.connections: | ||||
|             return | ||||
|         cls.connections.remove(client) | ||||
| 
 | ||||
|     @classmethod | ||||
|     def hasConnection(cls, client): | ||||
|         return client in cls.connections | ||||
|      | ||||
|     @classmethod | ||||
|     def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]): | ||||
|         if msg is None: | ||||
|             logger.critical("Tried to send 'none'") | ||||
|             return | ||||
| 
 | ||||
|         toRemove = [] | ||||
|         for client in cls.connections: | ||||
|             try: | ||||
|                 client.write_message(msg) | ||||
|             except tornado.websocket.WebSocketClosedError as e: | ||||
|                 logger.warning(f"Not properly closed websocket connection") | ||||
|                 toRemove.append(client) # If we remove it here from the set we get an exception about changing set size during iteration | ||||
| 
 | ||||
|         for client in toRemove: | ||||
|             cls.rmConnection(client) | ||||
| 
 | ||||
| class DemoHandler(tornado.web.RequestHandler): | ||||
|     def initialize(self, config: Namespace): | ||||
|         self.config = config | ||||
| 
 | ||||
|     def get(self): | ||||
|         self.render("index.html", ws_port=self.config.ws_port) | ||||
| 
 | ||||
| class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler): | ||||
|     def set_extra_headers(self, path): | ||||
|         """For subclass to add extra headers to the response""" | ||||
|         if path[-5:] == ".html": | ||||
|             self.set_header("Access-Control-Allow-Origin", "*") | ||||
|         if path[-4:] == ".svg": | ||||
|             self.set_header("Content-Type", "image/svg+xml") | ||||
| 
 | ||||
| 
 | ||||
| class WsRouter: | ||||
|     def __init__(self, config: Namespace): | ||||
|         self.config = config | ||||
| 
 | ||||
|         context = zmq.asyncio.Context() | ||||
|         self.trajectory_socket = context.socket(zmq.PUB) | ||||
|         self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr) | ||||
|          | ||||
|         self.prediction_socket = context.socket(zmq.SUB) | ||||
|         self.prediction_socket.connect(config.zmq_prediction_addr) | ||||
|         self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'') | ||||
| 
 | ||||
|         self.application = tornado.web.Application( | ||||
|             [ | ||||
|                 ( | ||||
|                     r"/ws/trajectory", | ||||
|                     WebSocketTrajectoryHandler, | ||||
|                     { | ||||
|                         "zmq_socket": self.trajectory_socket | ||||
|                     }, | ||||
|                 ), | ||||
|                 ( | ||||
|                     r"/ws/prediction", | ||||
|                     WebSocketPredictionHandler, | ||||
|                     { | ||||
|                         "config": config, | ||||
|                     }, | ||||
|                 ), | ||||
|                 (r"/", DemoHandler, {"config": config}), | ||||
|                 # (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}), | ||||
|             ], | ||||
|             template_path = 'trap/web/', | ||||
|             compiled_template_cache=False) | ||||
| 
 | ||||
|     def start(self): | ||||
| 
 | ||||
|         evt_loop = asyncio.new_event_loop() | ||||
|         asyncio.set_event_loop(evt_loop) | ||||
| 
 | ||||
|         # loop = tornado.ioloop.IOLoop.current() | ||||
|         logger.info(f"Listen on {self.config.ws_port}") | ||||
|         self.application.listen(self.config.ws_port) | ||||
|         loop = asyncio.get_event_loop() | ||||
| 
 | ||||
|         task = evt_loop.create_task(self.prediction_forwarder()) | ||||
| 
 | ||||
|         evt_loop.run_forever() | ||||
|      | ||||
|     async def prediction_forwarder(self): | ||||
|         logger.info("Starting prediction forwarder") | ||||
|         while True: | ||||
|             msg = await self.prediction_socket.recv_string() | ||||
|             logger.debug(f"Forward prediction message of {len(msg)} chars") | ||||
|             WebSocketPredictionHandler.write_to_clients(msg) | ||||
| 
 | ||||
| def run_ws_forwarder(config: Namespace): | ||||
|     router = WsRouter(config) | ||||
|     router.start() | ||||
							
								
								
									
										83
									
								
								trap/tracker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								trap/tracker.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,83 @@ | |||
| from argparse import Namespace | ||||
| import numpy as np | ||||
| import torch | ||||
| import zmq | ||||
| import cv2 | ||||
| 
 | ||||
| from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights | ||||
| from deep_sort_realtime.deepsort_tracker import DeepSort | ||||
| 
 | ||||
| Detection = [int, int, int, int, float, int] | ||||
| Detections = [Detection] | ||||
| 
 | ||||
| class Tracker: | ||||
|     def __init__(self, config: Namespace): | ||||
|          | ||||
|         context = zmq.Context() | ||||
|         self.frame_sock = context.socket(zmq.SUB) | ||||
|         self.frame_sock.bind(config.zmq_frame_addr) | ||||
|         self.frame_sock.setsockopt(zmq.CONFLATE, 1) # only keep latest frame | ||||
| 
 | ||||
|         # TODO: config device | ||||
|         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
| 
 | ||||
|         weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT | ||||
|         self.model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.35) | ||||
|         self.model.to(self.device) | ||||
|         # Put the model in inference mode | ||||
|         self.model.eval() | ||||
|         # Get the transforms for the model's weights | ||||
|         self.preprocess = weights.transforms().to(self.device) | ||||
| 
 | ||||
|         self.mot_tracker = DeepSort(max_age=5) | ||||
|          | ||||
| 
 | ||||
|     def track(self): | ||||
|         while True: | ||||
|             frame = self.frame_sock.recv() | ||||
|             detections = self.detect_persons(frame) | ||||
|             tracks = self.mot_tracker.update_tracks(detections, frame=frame) | ||||
| 
 | ||||
|             # TODO: provide a track object that actually keeps history (unlike tracker) | ||||
| 
 | ||||
| 
 | ||||
|     def detect_persons(self, frame) -> Detections: | ||||
|         t = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) | ||||
|         # change axes of image loaded image to be compatilbe with torch.io.read_image (which has C,W,H format instead of W,H,C) | ||||
|         t = t.permute(2, 0, 1) | ||||
| 
 | ||||
|         batch = self.preprocess(t)[None, :].to(self.device) | ||||
|         # no_grad can be used on inference, should be slightly faster | ||||
|         with torch.no_grad(): | ||||
|             predictions = self.model(batch) | ||||
|         prediction = predictions[0] # we feed only one frame at once | ||||
| 
 | ||||
|         # TODO: check if we need e.g. cyclist | ||||
|         mask = prediction['labels'] == 1 # if we want more than one label: np.isin(prediction['labels'], [1,86]) | ||||
| 
 | ||||
|         scores = prediction['scores'][mask] | ||||
|         labels = prediction['labels'][mask] | ||||
|         boxes = prediction['boxes'][mask] | ||||
|          | ||||
|         # TODO: introduce confidence and NMS supression: https://github.com/cfotache/pytorch_objectdetecttrack/blob/master/PyTorch_Object_Tracking.ipynb | ||||
|         # (which I _think_ we better do after filtering) | ||||
|         # alternatively look at Soft-NMS https://towardsdatascience.com/non-maximum-suppression-nms-93ce178e177c | ||||
| 
 | ||||
|         #  dets - a numpy array of detections in the format [[x1,y1,x2,y2,score, label],[x1,y1,x2,y2,score, label],...] | ||||
|         detections = np.array([np.append(bbox, [score, label]) for bbox, score, label in zip(boxes.cpu(), scores.cpu(), labels.cpu())]) | ||||
|         detections = self.detect_persons_deepsort_wrapper(detections) | ||||
|          | ||||
|         return detections | ||||
|      | ||||
|     @classmethod | ||||
|     def detect_persons_deepsort_wrapper(detections): | ||||
|         """make detect_persons() compatible with | ||||
|         deep_sort_realtime tracker by going from ltrb to ltwh and | ||||
|         different nesting | ||||
|         """ | ||||
|         return [([d[0], d[1], d[2]-d[0], d[3]-d[1]], d[4], d[5]) for d in detections] | ||||
| 
 | ||||
| 
 | ||||
| def run_tracker(config: Namespace): | ||||
|     router = Tracker(config) | ||||
|     router.track() | ||||
							
								
								
									
										197
									
								
								trap/web/index.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										197
									
								
								trap/web/index.html
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,197 @@ | |||
| <!DOCTYPE html> | ||||
| <html lang="en"> | ||||
| 
 | ||||
| <head> | ||||
|     <meta charset="UTF-8"> | ||||
|     <meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||||
|     <title>Trajectory Prediction Browser Test</title> | ||||
|     <style> | ||||
|         body { | ||||
|             background: black; | ||||
|         } | ||||
| 
 | ||||
|         #field { | ||||
|             background: white; | ||||
|             width: 100%; | ||||
|             height: 100%; | ||||
|         } | ||||
|     </style> | ||||
| </head> | ||||
| 
 | ||||
| <body> | ||||
|     <canvas id="field" width="1500" height="1500"> | ||||
| 
 | ||||
|     </canvas> | ||||
| 
 | ||||
|     <script> | ||||
|         // minified https://github.com/joewalnes/reconnecting-websocket | ||||
|         !function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a}); | ||||
|     </script> | ||||
| 
 | ||||
|     <script> | ||||
|         // map the field to coordinates of our dummy tracker | ||||
|         const field_range = { x: [-10, 10], y: [-10, 10] } | ||||
| 
 | ||||
|         // Create WebSocket connection. | ||||
|         const trajectory_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/trajectory`); | ||||
|         const prediction_socket = new WebSocket(`ws://${window.location.hostname}:{{ ws_port }}/ws/prediction`); | ||||
|         let is_moving = false; | ||||
|         const fieldEl = document.getElementById('field'); | ||||
| 
 | ||||
|         let current_data = {} | ||||
|         // Listen for messages | ||||
|         prediction_socket.addEventListener("message", (event) => { | ||||
|             // console.log("Message from server ", event.data); | ||||
|             current_data = JSON.parse(event.data); | ||||
|         }); | ||||
|         prediction_socket.addEventListener("open", (e) => appendAndSendPositions()); | ||||
| 
 | ||||
|         function getMousePos(canvas, evt) { | ||||
|             const rect = canvas.getBoundingClientRect(); | ||||
|             return { | ||||
|                 x: evt.clientX - rect.left, | ||||
|                 y: evt.clientY - rect.top | ||||
|             }; | ||||
|         } | ||||
|         function mouse_coordinates_to_position(coordinates) { | ||||
|             const x_range = field_range.x[1] - field_range.x[0] | ||||
|             const x = (coordinates.x / fieldEl.clientWidth) * x_range + field_range.x[0] | ||||
|             const y_range = field_range.y[1] - field_range.y[0] | ||||
|             const y = (coordinates.y / fieldEl.clientWidth) * y_range + field_range.y[0] | ||||
|             return { x: x, y: y } | ||||
|         } | ||||
|         function position_to_canvas_coordinate(position) { | ||||
|             const x_range = field_range.x[1] - field_range.x[0] | ||||
|             const y_range = field_range.y[1] - field_range.y[0] | ||||
|              | ||||
|             const x = Array.isArray(position) ? position[0] : position.x; | ||||
|             const y = Array.isArray(position) ? position[1] : position.y; | ||||
|             return { | ||||
|                 x: (x - field_range.x[0]) * fieldEl.width / x_range, | ||||
|                 y: (y - field_range.y[0]) * fieldEl.width / y_range, | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // helper function so we can spread | ||||
|         function coord_as_list(coord) { | ||||
|             return [coord.x, coord.y] | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         let tracker = {} | ||||
|         let person_counter = 0 | ||||
| 
 | ||||
|         class Person { | ||||
|             constructor(id) { | ||||
|                 this.id = id; | ||||
|                 this.history = []; | ||||
|                 this.prediction = [] | ||||
|             } | ||||
| 
 | ||||
|             addToHistory(position) { | ||||
|                 this.history.push(position); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         let current_pos = null; | ||||
|          | ||||
| 
 | ||||
|         function appendAndSendPositions(){ | ||||
|             if(is_moving && current_pos!==null){ | ||||
|                 // throttled update of tracker on movement | ||||
|                 tracker[person_counter].addToHistory(current_pos); | ||||
|             } | ||||
| 
 | ||||
|             for(const person_id in tracker){ | ||||
|                 if(person_id != person_counter){ // compare int/str | ||||
|                     // fade out old tracks | ||||
|                     tracker[person_id].history.shift() | ||||
|                     if(!tracker[person_id].history.length){ | ||||
|                         delete tracker[person_id] | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|             console.log(tracker) | ||||
|             trajectory_socket.send(JSON.stringify(tracker)) | ||||
| 
 | ||||
|             setTimeout(appendAndSendPositions, 200) | ||||
|         } | ||||
| 
 | ||||
|         fieldEl.addEventListener('mousedown', (event) => { | ||||
|             tracker[person_counter] = new Person(person_counter); | ||||
|             is_moving = true; | ||||
| 
 | ||||
|             const mousePos = getMousePos(fieldEl, event); | ||||
|             const position = mouse_coordinates_to_position(mousePos) | ||||
|             current_pos = position; | ||||
|             // tracker[person_counter].addToHistory(current_pos); | ||||
|             // trajectory_socket.send(JSON.stringify(tracker)) | ||||
|              | ||||
|         }); | ||||
|         fieldEl.addEventListener('mousemove', (event) => { | ||||
|             if (!is_moving) return; | ||||
|             const mousePos = getMousePos(fieldEl, event); | ||||
|             const position = mouse_coordinates_to_position(mousePos) | ||||
|             current_pos = position; | ||||
|             // tracker[person_counter].addToHistory(current_pos); | ||||
|             // trajectory_socket.send(JSON.stringify(tracker)) | ||||
|         }); | ||||
|         document.addEventListener('mouseup', (e) => { | ||||
|             person_counter++; | ||||
|             is_moving = false; | ||||
|         }) | ||||
| 
 | ||||
|         const ctx = fieldEl.getContext("2d"); | ||||
|         function drawFrame() { | ||||
|             ctx.clearRect(0, 0, fieldEl.width, fieldEl.height); | ||||
|             ctx.save(); | ||||
| 
 | ||||
|             for (let id in current_data) { | ||||
|                 const person = current_data[id]; | ||||
|                 if (person.history.length > 1) { | ||||
|                     const hist = structuredClone(person.history) | ||||
|                     // draw current position: | ||||
|                     ctx.beginPath() | ||||
|                     ctx.arc( | ||||
|                         ...coord_as_list(position_to_canvas_coordinate(hist[hist.length - 1])), | ||||
|                         5, //radius | ||||
|                         0, 2 * Math.PI); | ||||
|                     ctx.fill() | ||||
|                      | ||||
|                     ctx.beginPath() | ||||
|                     ctx.lineWidth = 3; | ||||
|                     ctx.strokeStyle = "#325FA2"; | ||||
| 
 | ||||
|                     ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(hist.shift()))); | ||||
|                     for (const position of hist) { | ||||
|                         ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position))) | ||||
|                     } | ||||
|                     ctx.stroke(); | ||||
|                 } | ||||
| 
 | ||||
|                 if(person.hasOwnProperty('predictions') && person.predictions.length > 0) { | ||||
|                     // multiple predictions can be sampled | ||||
|                     person.predictions.forEach((prediction, i) => { | ||||
|                         ctx.beginPath() | ||||
|                         ctx.lineWidth = i === 1 ? 3 : 0.2; | ||||
|                         ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa"; | ||||
|                          | ||||
|                         // start from current position: | ||||
|                         ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1]))); | ||||
|                         for (const position of prediction) { | ||||
|                             ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position))) | ||||
|                         } | ||||
|                         ctx.stroke(); | ||||
|                     }); | ||||
|                 } | ||||
|             } | ||||
|             ctx.restore(); | ||||
| 
 | ||||
|             window.requestAnimationFrame(drawFrame); | ||||
|         } | ||||
| 
 | ||||
|         window.requestAnimationFrame(drawFrame); | ||||
|     </script> | ||||
| </body> | ||||
| 
 | ||||
| </html> | ||||
		Loading…
	
		Reference in a new issue
	
	 Ruben van de Ven
						Ruben van de Ven