From ee57604b309608379f64f875cd4606699d0fe542 Mon Sep 17 00:00:00 2001 From: Ruben van de Ven Date: Wed, 11 Oct 2023 16:35:15 +0200 Subject: [PATCH] Render predictions to browser --- trajpred/plumber.py | 8 +-- trajpred/prediction_server.py | 104 +++++++++++++++++++--------------- trajpred/socket_forwarder.py | 4 +- trajpred/web/index.html | 31 +++++++++- 4 files changed, 90 insertions(+), 57 deletions(-) diff --git a/trajpred/plumber.py b/trajpred/plumber.py index 1c51445..2748e17 100644 --- a/trajpred/plumber.py +++ b/trajpred/plumber.py @@ -13,12 +13,6 @@ def start(): logging.basicConfig( level=loglevel, ) - # rootLogger = logging.getLogger() - # rootLogger.setLevel(loglevel) - - movement_q = Queue() - prediction_q = Queue() - # instantiating process with arguments procs = [ @@ -26,7 +20,7 @@ def start(): ] if not args.bypass_prediction: procs.append( - Process(target=run_inference_server, args=(args, movement_q, prediction_q)), + Process(target=run_inference_server, args=(args,)), ) logger.info("start") diff --git a/trajpred/prediction_server.py b/trajpred/prediction_server.py index b918557..e792575 100644 --- a/trajpred/prediction_server.py +++ b/trajpred/prediction_server.py @@ -9,13 +9,13 @@ import dill import random import pathlib import numpy as np -import trajectron.visualization as vis +from trajectron.utils import prediction_output_to_trajectories from trajectron.model.online.online_trajectron import OnlineTrajectron from trajectron.model.model_registrar import ModelRegistrar from trajectron.environment import Environment, Scene import matplotlib.pyplot as plt - +import zmq logger = logging.getLogger("trajpred.inference") @@ -102,10 +102,17 @@ def get_maps_for_input(input_dict, scene, hyperparams): class InferenceServer: - def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue): + def __init__(self, config: dict): self.config = config - self.movement_q = movement_q - self.prediction_q = prediction_q + + context = zmq.Context() + self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB) + self.trajectory_socket.connect(config.zmq_trajectory_addr) + self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'') + + self.prediction_socket: zmq.Socket = context.socket(zmq.PUB) + self.prediction_socket.bind(config.zmq_prediction_addr) + print(self.prediction_socket) def run(self): @@ -153,7 +160,7 @@ class InferenceServer: for scene in eval_env.scenes: scene.add_robot_from_nodes(eval_env.robot_type) - print('Loaded data from %s' % (self.config.eval_data_dict,)) + logger.info('Loaded data from %s' % (self.config.eval_data_dict,)) # Creating a dummy environment with a single scene that contains information about the world. # When using this code, feel free to use whichever scene index or initial timestep you wish. @@ -204,53 +211,60 @@ class InferenceServer: robot_present_and_future=robot_present_and_future, full_dist=True) end = time.time() - print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, + logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, 1. / (end - start), len(trajectron.nodes), trajectron.scene_graph.get_num_edges())) - detailed_preds_dict = dict() - for node in eval_scene.nodes: - if node in preds: - detailed_preds_dict[node] = preds[node] + # unsure what this bit from online_prediction.py does: + # detailed_preds_dict = dict() + # for node in eval_scene.nodes: + # if node in preds: + # detailed_preds_dict[node] = preds[node] - fig = plt.figure(figsize=(10,10)) - ax = fig.gca() - # fig, ax = plt.subplots() + #adapted from trajectron.visualization + # prediction_dict provides the actual predictions + # histories_dict provides the trajectory used for prediction + # futures_dict is the Ground Truth, which is unvailable in an online setting + prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds}, + eval_scene.dt, + hyperparams['maximum_history_length'], + hyperparams['prediction_horizon'] + ) + + assert(len(prediction_dict.keys()) <= 1) + if len(prediction_dict.keys()) == 0: + return + ts_key = list(prediction_dict.keys())[0] + + prediction_dict = prediction_dict[ts_key] + histories_dict = histories_dict[ts_key] + futures_dict = futures_dict[ts_key] + + response = {} - # vis.visualize_distribution(ax, - # dists) - vis.visualize_prediction(ax, - {timestep: preds}, - eval_scene.dt, - hyperparams['maximum_history_length'], - hyperparams['prediction_horizon']) + for node in histories_dict: + history = histories_dict[node] + # future = futures_dict[node] + predictions = prediction_dict[node] - if eval_scene.robot is not None and hyperparams['incl_robot_node']: - robot_for_plotting = eval_scene.robot.get(np.array([timestep, - timestep + hyperparams['prediction_horizon']]), - hyperparams['state'][eval_scene.robot.type]) - # robot_for_plotting += adjustment + if np.isnan(history[-1]).any(): + continue - ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0], - color='r', - linewidth=1.0, alpha=1.0) + response[node.id] = { + 'id': node.id, + 'history': history.tolist(), + 'predictions': predictions[0].tolist() # use batch 0 + } - # Current Node Position - circle = plt.Circle((robot_for_plotting[0, 1], - robot_for_plotting[0, 0]), - 0.3, - facecolor='r', - edgecolor='k', - lw=0.5, - zorder=3) - ax.add_artist(circle) + data = json.dumps(response) + self.prediction_socket.send_string(data) + # time.sleep(1) + # print(prediction_dict) + # print(histories_dict) + # print(futures_dict) - ax.set_xlim(-10,10) - ax.set_ylim(-10,10) - fig.suptitle(f"frame {timestep:04d}") - fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png')) - plt.close(fig) + -def run_inference_server(config, movement_q: Queue, prediction_q: Queue): - s = InferenceServer(config, movement_q, prediction_q) +def run_inference_server(config): + s = InferenceServer(config) s.run() \ No newline at end of file diff --git a/trajpred/socket_forwarder.py b/trajpred/socket_forwarder.py index ce81127..ff5fb74 100644 --- a/trajpred/socket_forwarder.py +++ b/trajpred/socket_forwarder.py @@ -108,7 +108,7 @@ class WsRouter: context = zmq.asyncio.Context() self.trajectory_socket = context.socket(zmq.PUB) - self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr) + self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr) self.prediction_socket = context.socket(zmq.SUB) self.prediction_socket.connect(config.zmq_prediction_addr) @@ -154,7 +154,7 @@ class WsRouter: logger.info("Starting prediction forwarder") while True: msg = await self.prediction_socket.recv_string() - logger.info("Forward: ") + logger.debug(f"Forward prediction message of {len(msg)} chars") WebSocketPredictionHandler.write_to_clients(msg) def run_ws_forwarder(config: Namespace): diff --git a/trajpred/web/index.html b/trajpred/web/index.html index 24db058..a2b589e 100644 --- a/trajpred/web/index.html +++ b/trajpred/web/index.html @@ -23,6 +23,11 @@ + +