diff --git a/trajpred/plumber.py b/trajpred/plumber.py
index 1c51445..2748e17 100644
--- a/trajpred/plumber.py
+++ b/trajpred/plumber.py
@@ -13,12 +13,6 @@ def start():
logging.basicConfig(
level=loglevel,
)
- # rootLogger = logging.getLogger()
- # rootLogger.setLevel(loglevel)
-
- movement_q = Queue()
- prediction_q = Queue()
-
# instantiating process with arguments
procs = [
@@ -26,7 +20,7 @@ def start():
]
if not args.bypass_prediction:
procs.append(
- Process(target=run_inference_server, args=(args, movement_q, prediction_q)),
+ Process(target=run_inference_server, args=(args,)),
)
logger.info("start")
diff --git a/trajpred/prediction_server.py b/trajpred/prediction_server.py
index b918557..e792575 100644
--- a/trajpred/prediction_server.py
+++ b/trajpred/prediction_server.py
@@ -9,13 +9,13 @@ import dill
import random
import pathlib
import numpy as np
-import trajectron.visualization as vis
+from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt
-
+import zmq
logger = logging.getLogger("trajpred.inference")
@@ -102,10 +102,17 @@ def get_maps_for_input(input_dict, scene, hyperparams):
class InferenceServer:
- def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue):
+ def __init__(self, config: dict):
self.config = config
- self.movement_q = movement_q
- self.prediction_q = prediction_q
+
+ context = zmq.Context()
+ self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
+ self.trajectory_socket.connect(config.zmq_trajectory_addr)
+ self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
+
+ self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
+ self.prediction_socket.bind(config.zmq_prediction_addr)
+ print(self.prediction_socket)
def run(self):
@@ -153,7 +160,7 @@ class InferenceServer:
for scene in eval_env.scenes:
scene.add_robot_from_nodes(eval_env.robot_type)
- print('Loaded data from %s' % (self.config.eval_data_dict,))
+ logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
# Creating a dummy environment with a single scene that contains information about the world.
# When using this code, feel free to use whichever scene index or initial timestep you wish.
@@ -204,53 +211,60 @@ class InferenceServer:
robot_present_and_future=robot_present_and_future,
full_dist=True)
end = time.time()
- print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
+ logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
1. / (end - start), len(trajectron.nodes),
trajectron.scene_graph.get_num_edges()))
- detailed_preds_dict = dict()
- for node in eval_scene.nodes:
- if node in preds:
- detailed_preds_dict[node] = preds[node]
+ # unsure what this bit from online_prediction.py does:
+ # detailed_preds_dict = dict()
+ # for node in eval_scene.nodes:
+ # if node in preds:
+ # detailed_preds_dict[node] = preds[node]
- fig = plt.figure(figsize=(10,10))
- ax = fig.gca()
- # fig, ax = plt.subplots()
+ #adapted from trajectron.visualization
+ # prediction_dict provides the actual predictions
+ # histories_dict provides the trajectory used for prediction
+ # futures_dict is the Ground Truth, which is unvailable in an online setting
+ prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
+ eval_scene.dt,
+ hyperparams['maximum_history_length'],
+ hyperparams['prediction_horizon']
+ )
+
+ assert(len(prediction_dict.keys()) <= 1)
+ if len(prediction_dict.keys()) == 0:
+ return
+ ts_key = list(prediction_dict.keys())[0]
+
+ prediction_dict = prediction_dict[ts_key]
+ histories_dict = histories_dict[ts_key]
+ futures_dict = futures_dict[ts_key]
+
+ response = {}
- # vis.visualize_distribution(ax,
- # dists)
- vis.visualize_prediction(ax,
- {timestep: preds},
- eval_scene.dt,
- hyperparams['maximum_history_length'],
- hyperparams['prediction_horizon'])
+ for node in histories_dict:
+ history = histories_dict[node]
+ # future = futures_dict[node]
+ predictions = prediction_dict[node]
- if eval_scene.robot is not None and hyperparams['incl_robot_node']:
- robot_for_plotting = eval_scene.robot.get(np.array([timestep,
- timestep + hyperparams['prediction_horizon']]),
- hyperparams['state'][eval_scene.robot.type])
- # robot_for_plotting += adjustment
+ if np.isnan(history[-1]).any():
+ continue
- ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0],
- color='r',
- linewidth=1.0, alpha=1.0)
+ response[node.id] = {
+ 'id': node.id,
+ 'history': history.tolist(),
+ 'predictions': predictions[0].tolist() # use batch 0
+ }
- # Current Node Position
- circle = plt.Circle((robot_for_plotting[0, 1],
- robot_for_plotting[0, 0]),
- 0.3,
- facecolor='r',
- edgecolor='k',
- lw=0.5,
- zorder=3)
- ax.add_artist(circle)
+ data = json.dumps(response)
+ self.prediction_socket.send_string(data)
+ # time.sleep(1)
+ # print(prediction_dict)
+ # print(histories_dict)
+ # print(futures_dict)
- ax.set_xlim(-10,10)
- ax.set_ylim(-10,10)
- fig.suptitle(f"frame {timestep:04d}")
- fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png'))
- plt.close(fig)
+
-def run_inference_server(config, movement_q: Queue, prediction_q: Queue):
- s = InferenceServer(config, movement_q, prediction_q)
+def run_inference_server(config):
+ s = InferenceServer(config)
s.run()
\ No newline at end of file
diff --git a/trajpred/socket_forwarder.py b/trajpred/socket_forwarder.py
index ce81127..ff5fb74 100644
--- a/trajpred/socket_forwarder.py
+++ b/trajpred/socket_forwarder.py
@@ -108,7 +108,7 @@ class WsRouter:
context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB)
- self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr)
+ self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.connect(config.zmq_prediction_addr)
@@ -154,7 +154,7 @@ class WsRouter:
logger.info("Starting prediction forwarder")
while True:
msg = await self.prediction_socket.recv_string()
- logger.info("Forward: ")
+ logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg)
def run_ws_forwarder(config: Namespace):
diff --git a/trajpred/web/index.html b/trajpred/web/index.html
index 24db058..a2b589e 100644
--- a/trajpred/web/index.html
+++ b/trajpred/web/index.html
@@ -23,6 +23,11 @@
+
+