From 9ba283ca9b630162d17c7585069bff022f9be628 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Wed, 11 Oct 2023 13:58:09 +0200 Subject: [PATCH] Trajectory prediction - test in browser --- run_server.py | 5 + trajpred/__init__.py | 0 trajpred/config.py | 142 +++++++++++++++++++ trajpred/plumber.py | 38 +++++ trajpred/prediction_server.py | 256 ++++++++++++++++++++++++++++++++++ trajpred/socket_forwarder.py | 162 +++++++++++++++++++++ trajpred/web/index.html | 145 +++++++++++++++++++ 7 files changed, 748 insertions(+) create mode 100644 run_server.py create mode 100644 trajpred/__init__.py create mode 100644 trajpred/config.py create mode 100644 trajpred/plumber.py create mode 100644 trajpred/prediction_server.py create mode 100644 trajpred/socket_forwarder.py create mode 100644 trajpred/web/index.html diff --git a/run_server.py b/run_server.py new file mode 100644 index 0000000..4f5c843 --- /dev/null +++ b/run_server.py @@ -0,0 +1,5 @@ +from trajpred import plumber + +if __name__ == "__main__": + plumber.start() + diff --git a/trajpred/__init__.py b/trajpred/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trajpred/config.py b/trajpred/config.py new file mode 100644 index 0000000..ad605f0 --- /dev/null +++ b/trajpred/config.py @@ -0,0 +1,142 @@ +import argparse + +parser = argparse.ArgumentParser() + + +parser.add_argument( + '--verbose', + '-v', + help="Increase verbosity. Add multiple times to increase further.", + action='count', default=0 +) + +# parser.add_argument('--foo') +inference_parser = parser.add_argument_group('inference server') +connection_parser = parser.add_argument_group('connection') + +inference_parser.add_argument("--model_dir", + help="directory with the model to use for inference", + type=str, # TODO: make into Path + default='./experiments/pedestrians/models/models_04_Oct_2023_21_04_48_eth_vel_ar3') + +inference_parser.add_argument("--conf", + help="path to json config file for hyperparameters, relative to model_dir", + type=str, + default='config.json') + +# Model Parameters (hyperparameters) +inference_parser.add_argument("--offline_scene_graph", + help="whether to precompute the scene graphs offline, options are 'no' and 'yes'", + type=str, + default='yes') + +inference_parser.add_argument("--dynamic_edges", + help="whether to use dynamic edges or not, options are 'no' and 'yes'", + type=str, + default='yes') + +inference_parser.add_argument("--edge_state_combine_method", + help="the method to use for combining edges of the same type", + type=str, + default='sum') + +inference_parser.add_argument("--edge_influence_combine_method", + help="the method to use for combining edge influences", + type=str, + default='attention') + +inference_parser.add_argument('--edge_addition_filter', + nargs='+', + help="what scaling to use for edges as they're created", + type=float, + default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp + # and short edge addition, then you need to have a 0.0 at the + # beginning, e.g. [0.0, 1.0]. + +inference_parser.add_argument('--edge_removal_filter', + nargs='+', + help="what scaling to use for edges as they're removed", + type=float, + default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like + # the default, then you need to have a 0.0 at the end. + + +inference_parser.add_argument('--incl_robot_node', + help="whether to include a robot node in the graph or simply model all agents", + action='store_true') + +inference_parser.add_argument('--map_encoding', + help="Whether to use map encoding or not", + action='store_true') + +inference_parser.add_argument('--no_edge_encoding', + help="Whether to use neighbors edge encoding", + action='store_true') + + +inference_parser.add_argument('--batch_size', + help='training batch size', + type=int, + default=256) + +inference_parser.add_argument('--k_eval', + help='how many samples to take during evaluation', + type=int, + default=25) + +# Data Parameters +inference_parser.add_argument("--eval_data_dict", + help="what file to load for evaluation data (WHEN NOT USING LIVE DATA)", + type=str, + default='./experiments/processed/eth_test.pkl') + +inference_parser.add_argument("--output_dir", + help="what dir to save output (i.e., saved models, logs, etc) (WHEN NOT USING LIVE OUTPUT)", + type=str, + default='../experiments/pedestrians/OUT/test_inference') + + +# inference_parser.add_argument('--device', +# help='what device to perform training on', +# type=str, +# default='cuda:0') + +inference_parser.add_argument("--eval_device", + help="what device to use during inference", + type=str, + default="cpu") + + +inference_parser.add_argument('--seed', + help='manual seed to use, default is 123', + type=int, + default=123) + + +# Internal connections. + +connection_parser.add_argument('--zmq-trajectory-addr', + help='Manually specity communication addr for the trajectory messages', + type=str, + default="ipc:///tmp/feeds/traj") + +connection_parser.add_argument('--zmq-camera-stream-addr', + help='Manually specity communication addr for the camera stream messages', + type=str, + default="ipc:///tmp/feeds/img") + +connection_parser.add_argument('--zmq-prediction-addr', + help='Manually specity communication addr for the prediction messages', + type=str, + default="ipc:///tmp/feeds/preds") + + +connection_parser.add_argument('--ws-port', + help='Port to listen for incomming websocket connections. Also serves the testing html-page.', + type=int, + default=8888) + +connection_parser.add_argument('--bypass-prediction', + help='For debugging purpose: websocket input immediately to output', + action='store_true') + diff --git a/trajpred/plumber.py b/trajpred/plumber.py new file mode 100644 index 0000000..1c51445 --- /dev/null +++ b/trajpred/plumber.py @@ -0,0 +1,38 @@ +import logging +from multiprocessing import Process, Queue +from trajpred.config import parser +from trajpred.prediction_server import InferenceServer, run_inference_server +from trajpred.socket_forwarder import run_ws_forwarder + + +logger = logging.getLogger("trajpred.plumbing") + +def start(): + args = parser.parse_args() + loglevel = logging.NOTSET if args.verbose > 1 else logging.DEBUG if args.verbose > 0 else logging.INFO + logging.basicConfig( + level=loglevel, + ) + # rootLogger = logging.getLogger() + # rootLogger.setLevel(loglevel) + + movement_q = Queue() + prediction_q = Queue() + + + # instantiating process with arguments + procs = [ + Process(target=run_ws_forwarder, args=(args,)) + ] + if not args.bypass_prediction: + procs.append( + Process(target=run_inference_server, args=(args, movement_q, prediction_q)), + ) + + logger.info("start") + for proc in procs: + proc.start() + + for proc in procs: + proc.join() + diff --git a/trajpred/prediction_server.py b/trajpred/prediction_server.py new file mode 100644 index 0000000..b918557 --- /dev/null +++ b/trajpred/prediction_server.py @@ -0,0 +1,256 @@ +# adapted from Trajectron++ online_server.py +import logging +from multiprocessing import Queue +import os +import time +import json +import torch +import dill +import random +import pathlib +import numpy as np +import trajectron.visualization as vis +from trajectron.model.online.online_trajectron import OnlineTrajectron +from trajectron.model.model_registrar import ModelRegistrar +from trajectron.environment import Environment, Scene +import matplotlib.pyplot as plt + + + +logger = logging.getLogger("trajpred.inference") + + +# if not torch.cuda.is_available() or self.config.device == 'cpu': +# self.config.device = torch.device('cpu') +# else: +# if torch.cuda.device_count() == 1: +# # If you have CUDA_VISIBLE_DEVICES set, which you should, +# # then this will prevent leftover flag arguments from +# # messing with the device allocation. +# self.config.device = 'cuda:0' + +# self.config.device = torch.device(self.config.device) + + + +def create_online_env(env, hyperparams, scene_idx, init_timestep): + test_scene = env.scenes[scene_idx] + + online_scene = Scene(timesteps=init_timestep + 1, + map=test_scene.map, + dt=test_scene.dt) + online_scene.nodes = test_scene.get_nodes_clipped_at_time( + timesteps=np.arange(init_timestep - hyperparams['maximum_history_length'], + init_timestep + 1), + state=hyperparams['state']) + online_scene.robot = test_scene.robot + online_scene.calculate_scene_graph(attention_radius=env.attention_radius, + edge_addition_filter=hyperparams['edge_addition_filter'], + edge_removal_filter=hyperparams['edge_removal_filter']) + + return Environment(node_type_list=env.node_type_list, + standardization=env.standardization, + scenes=[online_scene], + attention_radius=env.attention_radius, + robot_type=env.robot_type) + + +def get_maps_for_input(input_dict, scene, hyperparams): + scene_maps = list() + scene_pts = list() + heading_angles = list() + patch_sizes = list() + nodes_with_maps = list() + for node in input_dict: + if node.type in hyperparams['map_encoder']: + x = input_dict[node] + me_hyp = hyperparams['map_encoder'][node.type] + if 'heading_state_index' in me_hyp: + heading_state_index = me_hyp['heading_state_index'] + # We have to rotate the map in the opposit direction of the agent to match them + if type(heading_state_index) is list: # infer from velocity or heading vector + heading_angle = -np.arctan2(x[-1, heading_state_index[1]], + x[-1, heading_state_index[0]]) * 180 / np.pi + else: + heading_angle = -x[-1, heading_state_index] * 180 / np.pi + else: + heading_angle = None + + scene_map = scene.map[node.type] + map_point = x[-1, :2] + + patch_size = hyperparams['map_encoder'][node.type]['patch_size'] + + scene_maps.append(scene_map) + scene_pts.append(map_point) + heading_angles.append(heading_angle) + patch_sizes.append(patch_size) + nodes_with_maps.append(node) + + if heading_angles[0] is None: + heading_angles = None + else: + heading_angles = torch.Tensor(heading_angles) + + maps = scene_maps[0].get_cropped_maps_from_scene_map_batch(scene_maps, + scene_pts=torch.Tensor(scene_pts), + patch_size=patch_sizes[0], + rotation=heading_angles) + + maps_dict = {node: maps[[i]] for i, node in enumerate(nodes_with_maps)} + return maps_dict + + +class InferenceServer: + def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue): + self.config = config + self.movement_q = movement_q + self.prediction_q = prediction_q + + def run(self): + + if self.config.seed is not None: + random.seed(self.config.seed) + np.random.seed(self.config.seed) + torch.manual_seed(self.config.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(self.config.seed) + + # Choose one of the model directory names under the experiment/*/models folders. + # Possibilities are 'vel_ee', 'int_ee', 'int_ee_me', or 'robot' + # model_dir = os.path.join(self.config.log_dir, 'int_ee') + # model_dir = 'models/models_04_Oct_2023_21_04_48_eth_vel_ar3' + + # Load hyperparameters from json + config_file = os.path.join(self.config.model_dir, self.config.conf) + if not os.path.exists(config_file): + raise ValueError('Config json not found!') + with open(config_file, 'r') as conf_json: + hyperparams = json.load(conf_json) + + # Add hyperparams from arguments + hyperparams['dynamic_edges'] = self.config.dynamic_edges + hyperparams['edge_state_combine_method'] = self.config.edge_state_combine_method + hyperparams['edge_influence_combine_method'] = self.config.edge_influence_combine_method + hyperparams['edge_addition_filter'] = self.config.edge_addition_filter + hyperparams['edge_removal_filter'] = self.config.edge_removal_filter + hyperparams['batch_size'] = self.config.batch_size + hyperparams['k_eval'] = self.config.k_eval + hyperparams['offline_scene_graph'] = self.config.offline_scene_graph + hyperparams['incl_robot_node'] = self.config.incl_robot_node + hyperparams['edge_encoding'] = not self.config.no_edge_encoding + hyperparams['use_map_encoding'] = self.config.map_encoding + + output_save_dir = os.path.join(self.config.output_dir, 'pred_figs') + pathlib.Path(output_save_dir).mkdir(parents=True, exist_ok=True) + + + with open(self.config.eval_data_dict, 'rb') as f: + eval_env = dill.load(f, encoding='latin1') + + if eval_env.robot_type is None and hyperparams['incl_robot_node']: + eval_env.robot_type = eval_env.NodeType[0] # TODO: Make more general, allow the user to specify? + for scene in eval_env.scenes: + scene.add_robot_from_nodes(eval_env.robot_type) + + print('Loaded data from %s' % (self.config.eval_data_dict,)) + + # Creating a dummy environment with a single scene that contains information about the world. + # When using this code, feel free to use whichever scene index or initial timestep you wish. + scene_idx = 0 + + # You need to have at least acceleration, so you want 2 timesteps of prior data, e.g. [0, 1], + # so that you can immediately start incremental inference from the 3rd timestep onwards. + init_timestep = 1 + + eval_scene = eval_env.scenes[scene_idx] + online_env = create_online_env(eval_env, hyperparams, scene_idx, init_timestep) + + model_registrar = ModelRegistrar(self.config.model_dir, self.config.eval_device) + model_registrar.load_models(iter_num=100) + + trajectron = OnlineTrajectron(model_registrar, + hyperparams, + self.config.eval_device) + + # If you want to see what different robot futures do to the predictions, uncomment this line as well as + # related "... += adjustment" lines below. + # adjustment = np.stack([np.arange(13)/float(i*2.0) for i in range(6, 12)], axis=1) + + # Here's how you'd incrementally run the model, e.g. with streaming data. + trajectron.set_environment(online_env, init_timestep) + + for timestep in range(init_timestep + 1, eval_scene.timesteps): + input_dict = eval_scene.get_clipped_input_dict(timestep, hyperparams['state']) + + maps = None + if hyperparams['use_map_encoding']: + maps = get_maps_for_input(input_dict, eval_scene, hyperparams) + + robot_present_and_future = None + if eval_scene.robot is not None and hyperparams['incl_robot_node']: + robot_present_and_future = eval_scene.robot.get(np.array([timestep, + timestep + hyperparams['prediction_horizon']]), + hyperparams['state'][eval_scene.robot.type], + padding=0.0) + robot_present_and_future = np.stack([robot_present_and_future, robot_present_and_future], axis=0) + # robot_present_and_future += adjustment + + start = time.time() + dists, preds = trajectron.incremental_forward(input_dict, + maps, + prediction_horizon=6, + num_samples=51, + robot_present_and_future=robot_present_and_future, + full_dist=True) + end = time.time() + print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, + 1. / (end - start), len(trajectron.nodes), + trajectron.scene_graph.get_num_edges())) + + detailed_preds_dict = dict() + for node in eval_scene.nodes: + if node in preds: + detailed_preds_dict[node] = preds[node] + + fig = plt.figure(figsize=(10,10)) + ax = fig.gca() + # fig, ax = plt.subplots() + + # vis.visualize_distribution(ax, + # dists) + vis.visualize_prediction(ax, + {timestep: preds}, + eval_scene.dt, + hyperparams['maximum_history_length'], + hyperparams['prediction_horizon']) + + if eval_scene.robot is not None and hyperparams['incl_robot_node']: + robot_for_plotting = eval_scene.robot.get(np.array([timestep, + timestep + hyperparams['prediction_horizon']]), + hyperparams['state'][eval_scene.robot.type]) + # robot_for_plotting += adjustment + + ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0], + color='r', + linewidth=1.0, alpha=1.0) + + # Current Node Position + circle = plt.Circle((robot_for_plotting[0, 1], + robot_for_plotting[0, 0]), + 0.3, + facecolor='r', + edgecolor='k', + lw=0.5, + zorder=3) + ax.add_artist(circle) + + ax.set_xlim(-10,10) + ax.set_ylim(-10,10) + fig.suptitle(f"frame {timestep:04d}") + fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png')) + plt.close(fig) + +def run_inference_server(config, movement_q: Queue, prediction_q: Queue): + s = InferenceServer(config, movement_q, prediction_q) + s.run() \ No newline at end of file diff --git a/trajpred/socket_forwarder.py b/trajpred/socket_forwarder.py new file mode 100644 index 0000000..ce81127 --- /dev/null +++ b/trajpred/socket_forwarder.py @@ -0,0 +1,162 @@ + +from argparse import Namespace +import asyncio +import logging +from typing import Set, Union, Dict, Any +from typing_extensions import Self + +from urllib.error import HTTPError +import tornado.ioloop +import tornado.web +import tornado.websocket +import zmq +import zmq.asyncio + + +logger = logging.getLogger("trajpred.forwarder") + + +class WebSocketTrajectoryHandler(tornado.websocket.WebSocketHandler): + def initialize(self, zmq_socket: zmq.asyncio.Socket): + self.zmq_socket = zmq_socket + + async def on_message(self, message): + logger.info(f"recieve: {message}") + + try: + await self.zmq_socket.send_string(message) + # msg = json.loads(message) + except Exception as e: + # self.send({'alert': 'Invalid request: {}'.format(e)}) + logger.exception(e) + # self.write_message(u"You said: " + message) + + def open(self, p=None): + logger.info(f"connected {self.request.remote_ip}") + + # client disconnected + def on_close(self): + logger.info(f"Client disconnected: {self.request.remote_ip}") + + + +class WebSocketPredictionHandler(tornado.websocket.WebSocketHandler): + connections: Set[Self] = set() + + def initialize(self, config): + self.config = config + + def on_message(self, message): + logger.warning(f"Receiving message on send-only ws handler: {message}") + + def open(self, p=None): + logger.info(f"Prediction WS connected {self.request.remote_ip}") + self.__class__.connections.add(self) + + # client disconnected + def on_close(self): + self.__class__.rmConnection(self) + + logger.info(f"Client disconnected: {self.request.remote_ip}") + + @classmethod + def rmConnection(cls, client): + if client not in cls.connections: + return + cls.connections.remove(client) + + @classmethod + def hasConnection(cls, client): + return client in cls.connections + + @classmethod + def write_to_clients(cls, msg: Union[bytes, str, Dict[str, Any]]): + if msg is None: + logger.critical("Tried to send 'none'") + return + + toRemove = [] + for client in cls.connections: + try: + client.write_message(msg) + except tornado.websocket.WebSocketClosedError as e: + logger.warning(f"Not properly closed websocket connection") + toRemove.append(client) # If we remove it here from the set we get an exception about changing set size during iteration + + for client in toRemove: + cls.rmConnection(client) + +class DemoHandler(tornado.web.RequestHandler): + def initialize(self, config: Namespace): + self.config = config + + def get(self): + self.render("index.html", ws_port=self.config.ws_port) + +class StaticFileWithHeaderHandler(tornado.web.StaticFileHandler): + def set_extra_headers(self, path): + """For subclass to add extra headers to the response""" + if path[-5:] == ".html": + self.set_header("Access-Control-Allow-Origin", "*") + if path[-4:] == ".svg": + self.set_header("Content-Type", "image/svg+xml") + + +class WsRouter: + def __init__(self, config: Namespace): + self.config = config + + context = zmq.asyncio.Context() + self.trajectory_socket = context.socket(zmq.PUB) + self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr) + + self.prediction_socket = context.socket(zmq.SUB) + self.prediction_socket.connect(config.zmq_prediction_addr) + self.prediction_socket.setsockopt(zmq.SUBSCRIBE, b'') + + self.application = tornado.web.Application( + [ + ( + r"/ws/trajectory", + WebSocketTrajectoryHandler, + { + "zmq_socket": self.trajectory_socket + }, + ), + ( + r"/ws/prediction", + WebSocketPredictionHandler, + { + "config": config, + }, + ), + (r"/", DemoHandler, {"config": config}), + # (r"/(.*)", StaticFileWithHeaderHandler, {"config": config, "index": 'index.html'}), + ], + template_path = 'trajpred/web/', + compiled_template_cache=False) + + def start(self): + + evt_loop = asyncio.new_event_loop() + asyncio.set_event_loop(evt_loop) + + # loop = tornado.ioloop.IOLoop.current() + logger.info(f"Listen on {self.config.ws_port}") + self.application.listen(self.config.ws_port) + loop = asyncio.get_event_loop() + + task = evt_loop.create_task(self.prediction_forwarder()) + + evt_loop.run_forever() + + async def prediction_forwarder(self): + logger.info("Starting prediction forwarder") + while True: + msg = await self.prediction_socket.recv_string() + logger.info("Forward: ") + WebSocketPredictionHandler.write_to_clients(msg) + +def run_ws_forwarder(config: Namespace): + router = WsRouter(config) + router.start() \ No newline at end of file diff --git a/trajpred/web/index.html b/trajpred/web/index.html new file mode 100644 index 0000000..24db058 --- /dev/null +++ b/trajpred/web/index.html @@ -0,0 +1,145 @@ + + + + + + + Trajectory Prediction Browser Test + + + + + + + + + + + + \ No newline at end of file