Small batch preprocessing fix

Ensuring that the robot future standardization is with respect to a node's current position.
This commit is contained in:
Boris Ivanovic 2022-06-14 17:44:35 -04:00 committed by GitHub
parent 58248eea68
commit 58b9763d02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 6 deletions

View File

@ -151,17 +151,17 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state,
# Robot
robot_traj_st_t = None
timestep_range_r = np.array([t, t + max_ft])
if hyperparams['incl_robot_node']:
x_node = node.get(timestep_range_r, state[node.type])
timestep_range_r = np.array([t, t + max_ft])
if scene.non_aug_scene is not None:
robot = scene.get_node_by_id(scene.non_aug_scene.robot.id)
else:
robot = scene.robot
robot_type = robot.type
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan)
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)
robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
node_state = np.zeros_like(robot_traj[0])
node_state[:x.shape[1]] = x[-1]
robot_traj_st_t = get_relative_robot_traj(env, state, node_state, robot_traj, node.type, robot_type)
# Map
map_tuple = None
@ -231,4 +231,4 @@ def get_timesteps_data(env, scene, t, node_type, state, pred_state,
scene_graph=scene_graph))
if len(out_timesteps) == 0:
return None
return collate(batch), nodes, out_timesteps
return collate(batch), nodes, out_timesteps