Render predictions to browser
This commit is contained in:
parent 9ba283ca9b
commit ee57604b30
4 changed files with 90 additions and 57 deletions
@@ -13,12 +13,6 @@ def start():
    logging.basicConfig(
        level=loglevel,
    )
    # rootLogger = logging.getLogger()
    # rootLogger.setLevel(loglevel)

-   movement_q = Queue()
-   prediction_q = Queue()


    # instantiating process with arguments
    procs = [
@@ -26,7 +20,7 @@ def start():
    ]
    if not args.bypass_prediction:
        procs.append(
-           Process(target=run_inference_server, args=(args, movement_q, prediction_q)),
+           Process(target=run_inference_server, args=(args,)),
        )

    logger.info("start")

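The net effect of these two hunks is that start() no longer creates any multiprocessing Queues; the subprocesses are wired together purely through the ZMQ addresses carried in args. A minimal, self-contained sketch of that wiring (the stand-in worker bodies and the ipc:// addresses are illustrative, not from this repo; only the Process(...) calls mirror the diff above):

# Sketch of the simplified wiring after this commit; worker bodies are stand-ins.
from argparse import Namespace
from multiprocessing import Process


def run_ws_forwarder(args: Namespace):
    ...  # stand-in: binds the trajectory PUB socket and forwards predictions to websockets


def run_inference_server(args: Namespace):
    ...  # stand-in: SUBs to trajectories, PUBs predictions (see InferenceServer below)


def start(args: Namespace):
    procs = [
        Process(target=run_ws_forwarder, args=(args,)),
    ]
    if not args.bypass_prediction:
        procs.append(
            Process(target=run_inference_server, args=(args,)),
        )
    for proc in procs:
        proc.start()
    for proc in procs:
        proc.join()


if __name__ == "__main__":
    start(Namespace(bypass_prediction=False,
                    zmq_trajectory_addr="ipc:///tmp/trajectories",
                    zmq_prediction_addr="ipc:///tmp/predictions"))
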
@@ -9,13 +9,13 @@ import dill
import random
import pathlib
import numpy as np
import trajectron.visualization as vis
from trajectron.utils import prediction_output_to_trajectories
from trajectron.model.online.online_trajectron import OnlineTrajectron
from trajectron.model.model_registrar import ModelRegistrar
from trajectron.environment import Environment, Scene
import matplotlib.pyplot as plt


import zmq

logger = logging.getLogger("trajpred.inference")
@@ -102,10 +102,17 @@ def get_maps_for_input(input_dict, scene, hyperparams):


class InferenceServer:
-   def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue):
+   def __init__(self, config: dict):
        self.config = config
-       self.movement_q = movement_q
-       self.prediction_q = prediction_q
+
+       context = zmq.Context()
+       self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
+       self.trajectory_socket.connect(config.zmq_trajectory_addr)
+       self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
+
+       self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
+       self.prediction_socket.bind(config.zmq_prediction_addr)
+       print(self.prediction_socket)

    def run(self):

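The constructor now replaces the two queues with a ZMQ SUB socket for incoming trajectories and a PUB socket for outgoing predictions. Not part of the commit, but a throwaway listener like the following can be pointed at the PUB side to check what it emits (the address is illustrative; use whatever config.zmq_prediction_addr resolves to):

# Throwaway sketch: subscribe to the prediction PUB socket and summarise each frame.
# Assumes frames are JSON strings keyed by node id, as built in run() further down.
import json
import zmq


def watch_predictions(addr: str = "ipc:///tmp/predictions"):
    context = zmq.Context()
    socket = context.socket(zmq.SUB)
    socket.connect(addr)
    socket.setsockopt(zmq.SUBSCRIBE, b'')  # no topic filter: receive every message
    while True:
        frame = json.loads(socket.recv_string())
        for node_id, track in frame.items():
            print(f"node {node_id}: {len(track['history'])} history points, "
                  f"{len(track['predictions'])} sampled futures")
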
@@ -153,7 +160,7 @@ class InferenceServer:
        for scene in eval_env.scenes:
            scene.add_robot_from_nodes(eval_env.robot_type)

-       print('Loaded data from %s' % (self.config.eval_data_dict,))
+       logger.info('Loaded data from %s' % (self.config.eval_data_dict,))

        # Creating a dummy environment with a single scene that contains information about the world.
        # When using this code, feel free to use whichever scene index or initial timestep you wish.
@@ -204,53 +211,60 @@ class InferenceServer:
                robot_present_and_future=robot_present_and_future,
                full_dist=True)
            end = time.time()
-           print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
+           logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
                        1. / (end - start), len(trajectron.nodes),
                        trajectron.scene_graph.get_num_edges()))

-           detailed_preds_dict = dict()
-           for node in eval_scene.nodes:
-               if node in preds:
-                   detailed_preds_dict[node] = preds[node]
+           # unsure what this bit from online_prediction.py does:
+           # detailed_preds_dict = dict()
+           # for node in eval_scene.nodes:
+           #     if node in preds:
+           #         detailed_preds_dict[node] = preds[node]

            fig = plt.figure(figsize=(10,10))
            ax = fig.gca()
            # fig, ax = plt.subplots()
            #adapted from trajectron.visualization
            # prediction_dict provides the actual predictions
            # histories_dict provides the trajectory used for prediction
            # futures_dict is the Ground Truth, which is unvailable in an online setting
            prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
                                                                eval_scene.dt,
                                                                hyperparams['maximum_history_length'],
                                                                hyperparams['prediction_horizon']
                                                            )

            assert(len(prediction_dict.keys()) <= 1)
            if len(prediction_dict.keys()) == 0:
                return
            ts_key = list(prediction_dict.keys())[0]

            prediction_dict = prediction_dict[ts_key]
            histories_dict = histories_dict[ts_key]
            futures_dict = futures_dict[ts_key]

+           response = {}

            # vis.visualize_distribution(ax,
            #                            dists)
            vis.visualize_prediction(ax,
                                     {timestep: preds},
                                     eval_scene.dt,
                                     hyperparams['maximum_history_length'],
                                     hyperparams['prediction_horizon'])
            for node in histories_dict:
                history = histories_dict[node]
                # future = futures_dict[node]
                predictions = prediction_dict[node]

                if eval_scene.robot is not None and hyperparams['incl_robot_node']:
                    robot_for_plotting = eval_scene.robot.get(np.array([timestep,
                                                                        timestep + hyperparams['prediction_horizon']]),
                                                              hyperparams['state'][eval_scene.robot.type])
                    # robot_for_plotting += adjustment
                if np.isnan(history[-1]).any():
                    continue

                ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0],
                        color='r',
                        linewidth=1.0, alpha=1.0)
+               response[node.id] = {
+                   'id': node.id,
+                   'history': history.tolist(),
+                   'predictions': predictions[0].tolist() # use batch 0
+               }

                # Current Node Position
                circle = plt.Circle((robot_for_plotting[0, 1],
                                     robot_for_plotting[0, 0]),
                                    0.3,
                                    facecolor='r',
                                    edgecolor='k',
                                    lw=0.5,
                                    zorder=3)
                ax.add_artist(circle)
+           data = json.dumps(response)
+           self.prediction_socket.send_string(data)
            # time.sleep(1)
            # print(prediction_dict)
            # print(histories_dict)
            # print(futures_dict)

            ax.set_xlim(-10,10)
            ax.set_ylim(-10,10)
            fig.suptitle(f"frame {timestep:04d}")
            fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png'))
            plt.close(fig)


-def run_inference_server(config, movement_q: Queue, prediction_q: Queue):
-   s = InferenceServer(config, movement_q, prediction_q)
+def run_inference_server(config):
+   s = InferenceServer(config)
    s.run()

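For reference, the payload published per frame by send_string() above is a JSON object keyed by node id, holding that node's observed history and its sampled future trajectories. A hand-written example of the shape a consumer should expect (all coordinates are made up):

# Illustrative frame only; real messages are produced by InferenceServer.run() above.
import json

example_frame = {
    "1": {
        "id": "1",
        "history": [[0.0, 0.0], [0.1, 0.0], [0.2, 0.1]],   # past x/y positions, oldest first
        "predictions": [                                     # one list of points per sampled future
            [[0.3, 0.1], [0.4, 0.2], [0.5, 0.2]],
            [[0.3, 0.2], [0.5, 0.3], [0.6, 0.5]],
        ],
    },
}
data = json.dumps(example_frame)  # same shape as what prediction_socket.send_string(data) carries
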
@@ -108,7 +108,7 @@ class WsRouter:

        context = zmq.asyncio.Context()
        self.trajectory_socket = context.socket(zmq.PUB)
-       self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr)
+       self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)

        self.prediction_socket = context.socket(zmq.SUB)
        self.prediction_socket.connect(config.zmq_prediction_addr)
@@ -154,7 +154,7 @@ class WsRouter:
        logger.info("Starting prediction forwarder")
        while True:
            msg = await self.prediction_socket.recv_string()
-           logger.info("Forward: ")
+           logger.debug(f"Forward prediction message of {len(msg)} chars")
            WebSocketPredictionHandler.write_to_clients(msg)


def run_ws_forwarder(config: Namespace):

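The forwarder loop above relays every prediction frame to the connected websocket clients. Not part of the commit: a quick client sketch to confirm frames actually reach the browser-facing channel (URL and path are illustrative; use whatever address WsRouter is configured to serve; requires `pip install websockets`):

# Quick check of the websocket side of the pipeline.
import asyncio
import json

import websockets


async def watch_ws(url: str = "ws://localhost:8888/ws/prediction"):
    async with websockets.connect(url) as ws:
        while True:
            frame = json.loads(await ws.recv())
            print(f"prediction frame with {len(frame)} tracked nodes")


if __name__ == "__main__":
    asyncio.run(watch_ws())
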
@@ -23,6 +23,11 @@

    </canvas>

+   <script>
+       // minified https://github.com/joewalnes/reconnecting-websocket
+       !function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
+   </script>
+
    <script>
        // map the field to coordinates of our dummy tracker
        const field_range = { x: [-10, 10], y: [-10, 10] }
@@ -57,9 +62,12 @@
            function position_to_canvas_coordinate(position) {
                const x_range = field_range.x[1] - field_range.x[0]
                const y_range = field_range.y[1] - field_range.y[0]

+               const x = Array.isArray(position) ? position[0] : position.x;
+               const y = Array.isArray(position) ? position[1] : position.y;
                return {
-                   x: (position.x - field_range.x[0]) * fieldEl.width / x_range,
-                   y: (position.y - field_range.y[0]) * fieldEl.width / y_range,
+                   x: (x - field_range.x[0]) * fieldEl.width / x_range,
+                   y: (y - field_range.y[0]) * fieldEl.width / y_range,
                }
            }

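position_to_canvas_coordinate now accepts both [x, y] arrays (as sent in history/predictions) and {x, y} objects. The same mapping mirrored in Python, useful only for sanity-checking payloads offline; the field range matches the HTML above, while the 360 px canvas width is an assumption:

# Python mirror of position_to_canvas_coordinate, for offline sanity checks only.
FIELD_RANGE = {"x": (-10, 10), "y": (-10, 10)}
CANVAS_WIDTH = 360  # assumed canvas size; the HTML defines the real one


def position_to_canvas_coordinate(position):
    # accept either an [x, y] pair or an {"x": ..., "y": ...} mapping
    if isinstance(position, (list, tuple)):
        x, y = position[0], position[1]
    else:
        x, y = position["x"], position["y"]
    x_range = FIELD_RANGE["x"][1] - FIELD_RANGE["x"][0]
    y_range = FIELD_RANGE["y"][1] - FIELD_RANGE["y"][0]
    return {
        "x": (x - FIELD_RANGE["x"][0]) * CANVAS_WIDTH / x_range,
        "y": (y - FIELD_RANGE["y"][0]) * CANVAS_WIDTH / y_range,
    }


# (0, 0) is the middle of the 20 m x 20 m field, so it maps to the canvas centre:
assert position_to_canvas_coordinate([0, 0]) == {"x": 180.0, "y": 180.0}
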
@@ -110,6 +118,7 @@
            function drawFrame() {
                ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
                ctx.save();

                for (let id in current_data) {
                    const person = current_data[id];
                    if (person.history.length > 1) {
@@ -121,7 +130,7 @@
                            5, //radius
                            0, 2 * Math.PI);
                        ctx.fill()


                        ctx.beginPath()
                        ctx.lineWidth = 3;
                        ctx.strokeStyle = "#325FA2";
@@ -132,6 +141,22 @@
                        }
                        ctx.stroke();
                    }

+                   if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
+                       // multiple predictions can be sampled
+                       person.predictions.forEach((prediction, i) => {
+                           ctx.beginPath()
+                           ctx.lineWidth = i === 1 ? 3 : 0.2;
+                           ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
+
+                           // start from current position:
+                           ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
+                           for (const position of prediction) {
+                               ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
+                           }
+                           ctx.stroke();
+                       });
+                   }
                }
                ctx.restore();

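With the canvas code in place, the rendering can be exercised end to end without the model: run the websocket forwarder (without --bypass_prediction), leave the inference server stopped, and publish hand-made frames on the prediction address. A sketch of such a stand-in publisher (the address and the motion pattern are illustrative; the address must match config.zmq_prediction_addr):

# Publishes synthetic frames in the format drawFrame() expects, replacing the
# inference server purely for a rendering test.
import json
import time

import zmq


def publish_fake_frames(addr: str = "ipc:///tmp/predictions"):
    socket = zmq.Context().socket(zmq.PUB)
    socket.bind(addr)
    t = 0
    while True:
        x0 = (t * 0.1) % 10 - 5                                  # current x drifts across the field
        history = [[x0 - 0.1 * (7 - i), 0.0] for i in range(8)]  # eight past points ending at (x0, 0)
        predictions = [
            [[x0 + 0.2 * k, dy * k] for k in range(1, 9)]        # eight future points per sample
            for dy in (0.05, -0.05)                              # two sampled futures
        ]
        frame = {"0": {"id": "0", "history": history, "predictions": predictions}}
        socket.send_string(json.dumps(frame))
        time.sleep(0.1)
        t += 1


if __name__ == "__main__":
    publish_fake_frames()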