Small batch preprocessing fix
Ensuring that the robot future standardization is with respect to a node's current position.
This commit is contained in:
parent
58248eea68
commit
58b9763d02
1 changed files with 6 additions and 6 deletions
|
@ -151,17 +151,17 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state,
|
|||
|
||||
# Robot
|
||||
robot_traj_st_t = None
|
||||
timestep_range_r = np.array([t, t + max_ft])
|
||||
if hyperparams['incl_robot_node']:
|
||||
x_node = node.get(timestep_range_r, state[node.type])
|
||||
timestep_range_r = np.array([t, t + max_ft])
|
||||
if scene.non_aug_scene is not None:
|
||||
robot = scene.get_node_by_id(scene.non_aug_scene.robot.id)
|
||||
else:
|
||||
robot = scene.robot
|
||||
robot_type = robot.type
|
||||
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan)
|
||||
robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type)
|
||||
robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0
|
||||
robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0)
|
||||
node_state = np.zeros_like(robot_traj[0])
|
||||
node_state[:x.shape[1]] = x[-1]
|
||||
robot_traj_st_t = get_relative_robot_traj(env, state, node_state, robot_traj, node.type, robot_type)
|
||||
|
||||
# Map
|
||||
map_tuple = None
|
||||
|
@ -231,4 +231,4 @@ def get_timesteps_data(env, scene, t, node_type, state, pred_state,
|
|||
scene_graph=scene_graph))
|
||||
if len(out_timesteps) == 0:
|
||||
return None
|
||||
return collate(batch), nodes, out_timesteps
|
||||
return collate(batch), nodes, out_timesteps
|
||||
|
|
Loading…
Reference in a new issue