Small batch preprocessing fix

Ensuring that the robot future standardization is with respect to a node's current position.
This commit is contained in:
Boris Ivanovic 2022-06-14 17:44:35 -04:00 committed by GitHub
parent 58248eea68
commit 58b9763d02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -151,17 +151,17 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state,
# Robot # Robot
robot_traj_st_t = None robot_traj_st_t = None
timestep_range_r = np.array([t, t + max_ft])
if hyperparams['incl_robot_node']: if hyperparams['incl_robot_node']:
x_node = node.get(timestep_range_r, state[node.type]) timestep_range_r = np.array([t, t + max_ft])
if scene.non_aug_scene is not None: if scene.non_aug_scene is not None:
robot = scene.get_node_by_id(scene.non_aug_scene.robot.id) robot = scene.get_node_by_id(scene.non_aug_scene.robot.id)
else: else:
robot = scene.robot robot = scene.robot
robot_type = robot.type robot_type = robot.type
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan) robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type) node_state = np.zeros_like(robot_traj[0])
robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0 node_state[:x.shape[1]] = x[-1]
robot_traj_st_t = get_relative_robot_traj(env, state, node_state, robot_traj, node.type, robot_type)
# Map # Map
map_tuple = None map_tuple = None
@ -231,4 +231,4 @@ def get_timesteps_data(env, scene, t, node_type, state, pred_state,
scene_graph=scene_graph)) scene_graph=scene_graph))
if len(out_timesteps) == 0: if len(out_timesteps) == 0:
return None return None
return collate(batch), nodes, out_timesteps return collate(batch), nodes, out_timesteps