Render predictions to browser
This commit is contained in:
parent
9ba283ca9b
commit
ee57604b30
4 changed files with 90 additions and 57 deletions
|
@ -13,12 +13,6 @@ def start():
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=loglevel,
|
level=loglevel,
|
||||||
)
|
)
|
||||||
# rootLogger = logging.getLogger()
|
|
||||||
# rootLogger.setLevel(loglevel)
|
|
||||||
|
|
||||||
movement_q = Queue()
|
|
||||||
prediction_q = Queue()
|
|
||||||
|
|
||||||
|
|
||||||
# instantiating process with arguments
|
# instantiating process with arguments
|
||||||
procs = [
|
procs = [
|
||||||
|
@ -26,7 +20,7 @@ def start():
|
||||||
]
|
]
|
||||||
if not args.bypass_prediction:
|
if not args.bypass_prediction:
|
||||||
procs.append(
|
procs.append(
|
||||||
Process(target=run_inference_server, args=(args, movement_q, prediction_q)),
|
Process(target=run_inference_server, args=(args,)),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("start")
|
logger.info("start")
|
||||||
|
|
|
@ -9,13 +9,13 @@ import dill
|
||||||
import random
|
import random
|
||||||
import pathlib
|
import pathlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import trajectron.visualization as vis
|
from trajectron.utils import prediction_output_to_trajectories
|
||||||
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
from trajectron.model.online.online_trajectron import OnlineTrajectron
|
||||||
from trajectron.model.model_registrar import ModelRegistrar
|
from trajectron.model.model_registrar import ModelRegistrar
|
||||||
from trajectron.environment import Environment, Scene
|
from trajectron.environment import Environment, Scene
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
logger = logging.getLogger("trajpred.inference")
|
logger = logging.getLogger("trajpred.inference")
|
||||||
|
|
||||||
|
@ -102,10 +102,17 @@ def get_maps_for_input(input_dict, scene, hyperparams):
|
||||||
|
|
||||||
|
|
||||||
class InferenceServer:
|
class InferenceServer:
|
||||||
def __init__(self, config: dict, movement_q: Queue, prediction_q: Queue):
|
def __init__(self, config: dict):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.movement_q = movement_q
|
|
||||||
self.prediction_q = prediction_q
|
context = zmq.Context()
|
||||||
|
self.trajectory_socket: zmq.Socket = context.socket(zmq.SUB)
|
||||||
|
self.trajectory_socket.connect(config.zmq_trajectory_addr)
|
||||||
|
self.trajectory_socket.setsockopt(zmq.SUBSCRIBE, b'')
|
||||||
|
|
||||||
|
self.prediction_socket: zmq.Socket = context.socket(zmq.PUB)
|
||||||
|
self.prediction_socket.bind(config.zmq_prediction_addr)
|
||||||
|
print(self.prediction_socket)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
||||||
|
@ -153,7 +160,7 @@ class InferenceServer:
|
||||||
for scene in eval_env.scenes:
|
for scene in eval_env.scenes:
|
||||||
scene.add_robot_from_nodes(eval_env.robot_type)
|
scene.add_robot_from_nodes(eval_env.robot_type)
|
||||||
|
|
||||||
print('Loaded data from %s' % (self.config.eval_data_dict,))
|
logger.info('Loaded data from %s' % (self.config.eval_data_dict,))
|
||||||
|
|
||||||
# Creating a dummy environment with a single scene that contains information about the world.
|
# Creating a dummy environment with a single scene that contains information about the world.
|
||||||
# When using this code, feel free to use whichever scene index or initial timestep you wish.
|
# When using this code, feel free to use whichever scene index or initial timestep you wish.
|
||||||
|
@ -204,53 +211,60 @@ class InferenceServer:
|
||||||
robot_present_and_future=robot_present_and_future,
|
robot_present_and_future=robot_present_and_future,
|
||||||
full_dist=True)
|
full_dist=True)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
print("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
logger.info("t=%d: took %.2f s (= %.2f Hz) w/ %d nodes and %d edges" % (timestep, end - start,
|
||||||
1. / (end - start), len(trajectron.nodes),
|
1. / (end - start), len(trajectron.nodes),
|
||||||
trajectron.scene_graph.get_num_edges()))
|
trajectron.scene_graph.get_num_edges()))
|
||||||
|
|
||||||
detailed_preds_dict = dict()
|
# unsure what this bit from online_prediction.py does:
|
||||||
for node in eval_scene.nodes:
|
# detailed_preds_dict = dict()
|
||||||
if node in preds:
|
# for node in eval_scene.nodes:
|
||||||
detailed_preds_dict[node] = preds[node]
|
# if node in preds:
|
||||||
|
# detailed_preds_dict[node] = preds[node]
|
||||||
|
|
||||||
fig = plt.figure(figsize=(10,10))
|
#adapted from trajectron.visualization
|
||||||
ax = fig.gca()
|
# prediction_dict provides the actual predictions
|
||||||
# fig, ax = plt.subplots()
|
# histories_dict provides the trajectory used for prediction
|
||||||
|
# futures_dict is the Ground Truth, which is unvailable in an online setting
|
||||||
|
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
|
||||||
|
eval_scene.dt,
|
||||||
|
hyperparams['maximum_history_length'],
|
||||||
|
hyperparams['prediction_horizon']
|
||||||
|
)
|
||||||
|
|
||||||
|
assert(len(prediction_dict.keys()) <= 1)
|
||||||
|
if len(prediction_dict.keys()) == 0:
|
||||||
|
return
|
||||||
|
ts_key = list(prediction_dict.keys())[0]
|
||||||
|
|
||||||
|
prediction_dict = prediction_dict[ts_key]
|
||||||
|
histories_dict = histories_dict[ts_key]
|
||||||
|
futures_dict = futures_dict[ts_key]
|
||||||
|
|
||||||
|
response = {}
|
||||||
|
|
||||||
# vis.visualize_distribution(ax,
|
for node in histories_dict:
|
||||||
# dists)
|
history = histories_dict[node]
|
||||||
vis.visualize_prediction(ax,
|
# future = futures_dict[node]
|
||||||
{timestep: preds},
|
predictions = prediction_dict[node]
|
||||||
eval_scene.dt,
|
|
||||||
hyperparams['maximum_history_length'],
|
|
||||||
hyperparams['prediction_horizon'])
|
|
||||||
|
|
||||||
if eval_scene.robot is not None and hyperparams['incl_robot_node']:
|
if np.isnan(history[-1]).any():
|
||||||
robot_for_plotting = eval_scene.robot.get(np.array([timestep,
|
continue
|
||||||
timestep + hyperparams['prediction_horizon']]),
|
|
||||||
hyperparams['state'][eval_scene.robot.type])
|
|
||||||
# robot_for_plotting += adjustment
|
|
||||||
|
|
||||||
ax.plot(robot_for_plotting[1:, 1], robot_for_plotting[1:, 0],
|
response[node.id] = {
|
||||||
color='r',
|
'id': node.id,
|
||||||
linewidth=1.0, alpha=1.0)
|
'history': history.tolist(),
|
||||||
|
'predictions': predictions[0].tolist() # use batch 0
|
||||||
|
}
|
||||||
|
|
||||||
# Current Node Position
|
data = json.dumps(response)
|
||||||
circle = plt.Circle((robot_for_plotting[0, 1],
|
self.prediction_socket.send_string(data)
|
||||||
robot_for_plotting[0, 0]),
|
# time.sleep(1)
|
||||||
0.3,
|
# print(prediction_dict)
|
||||||
facecolor='r',
|
# print(histories_dict)
|
||||||
edgecolor='k',
|
# print(futures_dict)
|
||||||
lw=0.5,
|
|
||||||
zorder=3)
|
|
||||||
ax.add_artist(circle)
|
|
||||||
|
|
||||||
ax.set_xlim(-10,10)
|
|
||||||
ax.set_ylim(-10,10)
|
|
||||||
fig.suptitle(f"frame {timestep:04d}")
|
|
||||||
fig.savefig(os.path.join(output_save_dir, f'pred_{timestep:04d}.png'))
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
def run_inference_server(config, movement_q: Queue, prediction_q: Queue):
|
def run_inference_server(config):
|
||||||
s = InferenceServer(config, movement_q, prediction_q)
|
s = InferenceServer(config)
|
||||||
s.run()
|
s.run()
|
|
@ -108,7 +108,7 @@ class WsRouter:
|
||||||
|
|
||||||
context = zmq.asyncio.Context()
|
context = zmq.asyncio.Context()
|
||||||
self.trajectory_socket = context.socket(zmq.PUB)
|
self.trajectory_socket = context.socket(zmq.PUB)
|
||||||
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajection_addr)
|
self.trajectory_socket.bind(config.zmq_prediction_addr if config.bypass_prediction else config.zmq_trajectory_addr)
|
||||||
|
|
||||||
self.prediction_socket = context.socket(zmq.SUB)
|
self.prediction_socket = context.socket(zmq.SUB)
|
||||||
self.prediction_socket.connect(config.zmq_prediction_addr)
|
self.prediction_socket.connect(config.zmq_prediction_addr)
|
||||||
|
@ -154,7 +154,7 @@ class WsRouter:
|
||||||
logger.info("Starting prediction forwarder")
|
logger.info("Starting prediction forwarder")
|
||||||
while True:
|
while True:
|
||||||
msg = await self.prediction_socket.recv_string()
|
msg = await self.prediction_socket.recv_string()
|
||||||
logger.info("Forward: ")
|
logger.debug(f"Forward prediction message of {len(msg)} chars")
|
||||||
WebSocketPredictionHandler.write_to_clients(msg)
|
WebSocketPredictionHandler.write_to_clients(msg)
|
||||||
|
|
||||||
def run_ws_forwarder(config: Namespace):
|
def run_ws_forwarder(config: Namespace):
|
||||||
|
|
|
@ -23,6 +23,11 @@
|
||||||
|
|
||||||
</canvas>
|
</canvas>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// minified https://github.com/joewalnes/reconnecting-websocket
|
||||||
|
!function(a,b){"function"==typeof define&&define.amd?define([],b):"undefined"!=typeof module&&module.exports?module.exports=b():a.ReconnectingWebSocket=b()}(this,function(){function a(b,c,d){function l(a,b){var c=document.createEvent("CustomEvent");return c.initCustomEvent(a,!1,!1,b),c}var e={debug:!1,automaticOpen:!0,reconnectInterval:1e3,maxReconnectInterval:3e4,reconnectDecay:1.5,timeoutInterval:2e3};d||(d={});for(var f in e)this[f]="undefined"!=typeof d[f]?d[f]:e[f];this.url=b,this.reconnectAttempts=0,this.readyState=WebSocket.CONNECTING,this.protocol=null;var h,g=this,i=!1,j=!1,k=document.createElement("div");k.addEventListener("open",function(a){g.onopen(a)}),k.addEventListener("close",function(a){g.onclose(a)}),k.addEventListener("connecting",function(a){g.onconnecting(a)}),k.addEventListener("message",function(a){g.onmessage(a)}),k.addEventListener("error",function(a){g.onerror(a)}),this.addEventListener=k.addEventListener.bind(k),this.removeEventListener=k.removeEventListener.bind(k),this.dispatchEvent=k.dispatchEvent.bind(k),this.open=function(b){h=new WebSocket(g.url,c||[]),b||k.dispatchEvent(l("connecting")),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","attempt-connect",g.url);var d=h,e=setTimeout(function(){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","connection-timeout",g.url),j=!0,d.close(),j=!1},g.timeoutInterval);h.onopen=function(){clearTimeout(e),(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onopen",g.url),g.protocol=h.protocol,g.readyState=WebSocket.OPEN,g.reconnectAttempts=0;var d=l("open");d.isReconnect=b,b=!1,k.dispatchEvent(d)},h.onclose=function(c){if(clearTimeout(e),h=null,i)g.readyState=WebSocket.CLOSED,k.dispatchEvent(l("close"));else{g.readyState=WebSocket.CONNECTING;var d=l("connecting");d.code=c.code,d.reason=c.reason,d.wasClean=c.wasClean,k.dispatchEvent(d),b||j||((g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onclose",g.url),k.dispatchEvent(l("close")));var e=g.reconnectInterval*Math.pow(g.reconnectDecay,g.reconnectAttempts);setTimeout(function(){g.reconnectAttempts++,g.open(!0)},e>g.maxReconnectInterval?g.maxReconnectInterval:e)}},h.onmessage=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onmessage",g.url,b.data);var c=l("message");c.data=b.data,k.dispatchEvent(c)},h.onerror=function(b){(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","onerror",g.url,b),k.dispatchEvent(l("error"))}},1==this.automaticOpen&&this.open(!1),this.send=function(b){if(h)return(g.debug||a.debugAll)&&console.debug("ReconnectingWebSocket","send",g.url,b),h.send(b);throw"INVALID_STATE_ERR : Pausing to reconnect websocket"},this.close=function(a,b){"undefined"==typeof a&&(a=1e3),i=!0,h&&h.close(a,b)},this.refresh=function(){h&&h.close()}}return a.prototype.onopen=function(){},a.prototype.onclose=function(){},a.prototype.onconnecting=function(){},a.prototype.onmessage=function(){},a.prototype.onerror=function(){},a.debugAll=!1,a.CONNECTING=WebSocket.CONNECTING,a.OPEN=WebSocket.OPEN,a.CLOSING=WebSocket.CLOSING,a.CLOSED=WebSocket.CLOSED,a});
|
||||||
|
</script>
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
// map the field to coordinates of our dummy tracker
|
// map the field to coordinates of our dummy tracker
|
||||||
const field_range = { x: [-10, 10], y: [-10, 10] }
|
const field_range = { x: [-10, 10], y: [-10, 10] }
|
||||||
|
@ -57,9 +62,12 @@
|
||||||
function position_to_canvas_coordinate(position) {
|
function position_to_canvas_coordinate(position) {
|
||||||
const x_range = field_range.x[1] - field_range.x[0]
|
const x_range = field_range.x[1] - field_range.x[0]
|
||||||
const y_range = field_range.y[1] - field_range.y[0]
|
const y_range = field_range.y[1] - field_range.y[0]
|
||||||
|
|
||||||
|
const x = Array.isArray(position) ? position[0] : position.x;
|
||||||
|
const y = Array.isArray(position) ? position[1] : position.y;
|
||||||
return {
|
return {
|
||||||
x: (position.x - field_range.x[0]) * fieldEl.width / x_range,
|
x: (x - field_range.x[0]) * fieldEl.width / x_range,
|
||||||
y: (position.y - field_range.y[0]) * fieldEl.width / y_range,
|
y: (y - field_range.y[0]) * fieldEl.width / y_range,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -110,6 +118,7 @@
|
||||||
function drawFrame() {
|
function drawFrame() {
|
||||||
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
|
ctx.clearRect(0, 0, fieldEl.width, fieldEl.height);
|
||||||
ctx.save();
|
ctx.save();
|
||||||
|
|
||||||
for (let id in current_data) {
|
for (let id in current_data) {
|
||||||
const person = current_data[id];
|
const person = current_data[id];
|
||||||
if (person.history.length > 1) {
|
if (person.history.length > 1) {
|
||||||
|
@ -121,7 +130,7 @@
|
||||||
5, //radius
|
5, //radius
|
||||||
0, 2 * Math.PI);
|
0, 2 * Math.PI);
|
||||||
ctx.fill()
|
ctx.fill()
|
||||||
|
|
||||||
ctx.beginPath()
|
ctx.beginPath()
|
||||||
ctx.lineWidth = 3;
|
ctx.lineWidth = 3;
|
||||||
ctx.strokeStyle = "#325FA2";
|
ctx.strokeStyle = "#325FA2";
|
||||||
|
@ -132,6 +141,22 @@
|
||||||
}
|
}
|
||||||
ctx.stroke();
|
ctx.stroke();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(person.hasOwnProperty('predictions') && person.predictions.length > 0) {
|
||||||
|
// multiple predictions can be sampled
|
||||||
|
person.predictions.forEach((prediction, i) => {
|
||||||
|
ctx.beginPath()
|
||||||
|
ctx.lineWidth = i === 1 ? 3 : 0.2;
|
||||||
|
ctx.strokeStyle = i === 1 ? "#ff0000" : "#ccaaaa";
|
||||||
|
|
||||||
|
// start from current position:
|
||||||
|
ctx.moveTo(...coord_as_list(position_to_canvas_coordinate(person.history[person.history.length - 1])));
|
||||||
|
for (const position of prediction) {
|
||||||
|
ctx.lineTo(...coord_as_list(position_to_canvas_coordinate(position)))
|
||||||
|
}
|
||||||
|
ctx.stroke();
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ctx.restore();
|
ctx.restore();
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue