This commit is contained in:
Ruben van de Ven 2023-12-06 10:25:01 +01:00
parent ec9bb357fd
commit f3b8e031c1

View file

@ -270,10 +270,10 @@ class PredictionServer:
x = history[:, 0] x = history[:, 0]
y = history[:, 1] y = history[:, 1]
# TODO: calculate dt based on input # TODO: calculate dt based on input
vx = derivative_of(x, 0.2) #eval_scene.dt vx = derivative_of(x, 0.1) #eval_scene.dt
vy = derivative_of(y, 0.2) vy = derivative_of(y, 0.1)
ax = derivative_of(vx, 0.2) ax = derivative_of(vx, 0.1)
ay = derivative_of(vy, 0.2) ay = derivative_of(vy, 0.1)
data_dict = {('position', 'x'): x[:], data_dict = {('position', 'x'): x[:],
('position', 'y'): y[:], ('position', 'y'): y[:],
@ -325,7 +325,7 @@ class PredictionServer:
dists, preds = trajectron.incremental_forward(input_dict, dists, preds = trajectron.incremental_forward(input_dict,
maps, maps,
prediction_horizon=25, # TODO: make variable prediction_horizon=25, # TODO: make variable
num_samples=20, # TODO: make variable num_samples=5, # TODO: make variable
robot_present_and_future=robot_present_and_future, robot_present_and_future=robot_present_and_future,
full_dist=True) full_dist=True)
end = time.time() end = time.time()
@ -343,6 +343,7 @@ class PredictionServer:
# prediction_dict provides the actual predictions # prediction_dict provides the actual predictions
# histories_dict provides the trajectory used for prediction # histories_dict provides the trajectory used for prediction
# futures_dict is the Ground Truth, which is unvailable in an online setting # futures_dict is the Ground Truth, which is unvailable in an online setting
prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds}, prediction_dict, histories_dict, futures_dict = prediction_output_to_trajectories({timestep: preds},
eval_scene.dt, eval_scene.dt,
hyperparams['maximum_history_length'], hyperparams['maximum_history_length'],
@ -371,6 +372,8 @@ class PredictionServer:
response[node.id] = { response[node.id] = {
'id': node.id, 'id': node.id,
'det_conf': trajectory_data[node.id]['det_conf'],
'bbox': trajectory_data[node.id]['bbox'],
'history': history.tolist(), 'history': history.tolist(),
'predictions': predictions[0].tolist() # use batch 0 'predictions': predictions[0].tolist() # use batch 0
} }