Render predictions to browser

This commit is contained in:
Ruben van de Ven 2023-10-11 16:35:15 +02:00
parent 9ba283ca9b
commit ee57604b30
4 changed files with 90 additions and 57 deletions

View file

@ -13,12 +13,6 @@ def start():
logging.basicConfig( logging.basicConfig(
level=loglevel, level=loglevel,
) )
# rootLogger = logging.getLogger()
# rootLogger.setLevel(loglevel)
movement_q = Queue()
prediction_q = Queue()
# instantiating process with arguments # instantiating process with arguments
procs = [ procs = [
@ -26,7 +20,7 @@ def start():
] ]
if not args.bypass_prediction: if not args.bypass_prediction:
procs.append( procs.append(
Process(target=run_inference_server, args=(args, movement_q, prediction_q)), Process(target=run_inference_server, args=(args,)),
) )
logger.info("start") logger.info("start")

View file

@ -9,13 +9,13 @@ import dill
import random import random
import pathlib import pathlib
import numpy as np import numpy as np
import trajectron.visualization as vis from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import zmq
logger = logging.getLogger("trajpred.inference") logger = logging.getLogger("trajpred.inference")
@ -102,10 +102,17 @@ def get_maps_for_input(input_dict, scene, hyperparams):
class InferenceServer: class InferenceServer:
def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue): def __init__(self, config: dict):
self.config = config self.config = config
self.movement_q = movement_q
self.prediction_q = prediction_q context = zmq.Context()
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
self.trajectory_socket.connect(config.zmq_trajectory_addr)
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
self.prediction_socket.bind(config.zmq_prediction_addr)
print(self.prediction_socket)
def run(self): def run(self):
@ -153,7 +160,7 @@ class InferenceServer:
for scene in eval_env.scenes: for scene in eval_env.scenes:
scene.add_robot_from_nodes(eval_env.robot_type) scene.add_robot_from_nodes(eval_env.robot_type)
print('Loaded data from %s' % (self.config.eval_data_dict,)) logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
# Creating a dummy environment with a single scene that contains information about the world. # Creating a dummy environment with a single scene that contains information about the world.
# When using this code, feel free to use whichever scene index or initial timestep you wish. # When using this code, feel free to use whichever scene index or initial timestep you wish.
@ -204,53 +211,60 @@ class InferenceServer:
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)
end = time.time() end = time.time()
print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start, logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
1. / (end - start), len(trajectron.nodes), 1. / (end - start), len(trajectron.nodes),
trajectron.scene_graph.get_num_edges())) trajectron.scene_graph.get_num_edges()))
detailed_preds_dict = dict() # unsure what this bit from online_prediction.py does:
for node in eval_scene.nodes: # detailed_preds_dict = dict()
if node in preds: # for node in eval_scene.nodes:
detailed_preds_dict[node] = preds[node] # if node in preds:
# detailed_preds_dict[node] = preds[node]
fig = plt.figure(figsize=(10,10)) #adapted from trajectron.visualization
ax = fig.gca() # prediction_dict provides the actual predictions
# fig, ax = plt.subplots() # histories_dict provides the trajectory used for prediction
# futures_dict is the Ground Truth, which is unvailable in an online setting
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
eval_scene.dt,
hyperparams['maximum_history_length'],
hyperparams['prediction_horizon']
)
# vis.visualize_distribution(ax, assert(len(prediction_dict.keys()) <= 1)
# dists) if len(prediction_dict.keys()) == 0:
vis.visualize_prediction(ax, return
{timestep: preds}, ts_key = list(prediction_dict.keys())[0]
eval_scene.dt,
hyperparams['maximum_history_length'],
hyperparams['prediction_horizon'])
if eval_scene.robot is not None and hyperparams['incl_robot_node']: prediction_dict = prediction_dict[ts_key]
robot_for_plotting = eval_scene.robot.get(np.array([timestep, histories_dict = histories_dict[ts_key]
timestep + hyperparams['prediction_horizon']]), futures_dict = futures_dict[ts_key]
hyperparams['state'][eval_scene.robot.type])
# robot_for_plotting += adjustment
ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0], response = {}
color='r',
linewidth=1.0, alpha=1.0)
# Current Node Position for node in histories_dict:
circle = plt.Circle((robot_for_plotting[0, 1], history = histories_dict[node]
robot_for_plotting[0, 0]), # future = futures_dict[node]
0.3, predictions = prediction_dict[node]
facecolor='r',
edgecolor='k',
lw=0.5,
zorder=3)
ax.add_artist(circle)
ax.set_xlim(-10,10) if np.isnan(history[-1]).any():
ax.set_ylim(-10,10) continue
fig.suptitle(f"frame {timestep:04d}")
fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png'))
plt.close(fig)
def run_inference_server(config, movement_q: Queue, prediction_q: Queue): response[node.id] = {
s = InferenceServer(config, movement_q, prediction_q) 'id': node.id,
'history': history.tolist(),
'predictions': predictions[0].tolist() # use batch 0
}
data = json.dumps(response)
self.prediction_socket.send_string(data)
# time.sleep(1)
# print(prediction_dict)
# print(histories_dict)
# print(futures_dict)
def run_inference_server(config):
s = InferenceServer(config)
s.run() s.run()

View file

@ -108,7 +108,7 @@ class WsRouter:
context = zmq.asyncio.Context() context = zmq.asyncio.Context()
self.trajectory_socket = context.socket(zmq.PUB) self.trajectory_socket = context.socket(zmq.PUB)
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr) self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
self.prediction_socket = context.socket(zmq.SUB) self.prediction_socket = context.socket(zmq.SUB)
self.prediction_socket.connect(config.zmq_prediction_addr) self.prediction_socket.connect(config.zmq_prediction_addr)
@ -154,7 +154,7 @@ class WsRouter:
logger.info("Starting prediction forwarder") logger.info("Starting prediction forwarder")
while True: while True:
msg = await self.prediction_socket.recv_string() msg = await self.prediction_socket.recv_string()
logger.info("Forward: ") logger.debug(f"Forward prediction message of {len(msg)} chars")
WebSocketPredictionHandler.write_to_clients(msg) WebSocketPredictionHandler.write_to_clients(msg)
def run_ws_forwarder(config: Namespace): def run_ws_forwarder(config: Namespace):

View file

@ -23,6 +23,11 @@
</canvas> </canvas>
<script>
// minified https://github.com/joewalnes/reconnecting-websocket
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
</script>
<script> <script>
// map the field to coordinates of our dummy tracker // map the field to coordinates of our dummy tracker
const field_range = { x: [-10, 10], y: [-10, 10] } const field_range = { x: [-10, 10], y: [-10, 10] }
@ -57,9 +62,12 @@
function position_to_canvas_coordinate(position) { function position_to_canvas_coordinate(position) {
const x_range = field_range.x[1] - field_range.x[0] const x_range = field_range.x[1] - field_range.x[0]
const y_range = field_range.y[1] - field_range.y[0] const y_range = field_range.y[1] - field_range.y[0]
const x = Array.isArray(position) ? position[0] : position.x;
const y = Array.isArray(position) ? position[1] : position.y;
return { return {
x: (position.x - field_range.x[0]) * fieldEl.width / x_range, x: (x - field_range.x[0]) * fieldEl.width / x_range,
y: (position.y - field_range.y[0]) * fieldEl.width / y_range, y: (y - field_range.y[0]) * fieldEl.width / y_range,
} }
} }
@ -110,6 +118,7 @@
function drawFrame() { function drawFrame() {
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height); ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
ctx.save(); ctx.save();
for (let id in current_data) { for (let id in current_data) {
const person = current_data[id]; const person = current_data[id];
if (person.history.length > 1) { if (person.history.length > 1) {
@ -132,6 +141,22 @@
} }
ctx.stroke(); ctx.stroke();
} }
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
// multiple predictions can be sampled
person.predictions.forEach((prediction, i) => {
ctx.beginPath()
ctx.lineWidth = i === 1 ? 3 : 0.2;
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
// start from current position:
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
for (const position of prediction) {
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
}
ctx.stroke();
});
}
} }
ctx.restore(); ctx.restore();