diff --git a/trajectron/model/dataset/preprocessing.py b/trajectron/model/dataset/preprocessing.py index c8a0e8f..844d306 100644 --- a/trajectron/model/dataset/preprocessing.py +++ b/trajectron/model/dataset/preprocessing.py @@ -151,17 +151,17 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, # Robot robot_traj_st_t = None - timestep_range_r = np.array([t, t + max_ft]) if hyperparams['incl_robot_node']: - x_node = node.get(timestep_range_r, state[node.type]) + timestep_range_r = np.array([t, t + max_ft]) if scene.non_aug_scene is not None: robot = scene.get_node_by_id(scene.non_aug_scene.robot.id) else: robot = scene.robot robot_type = robot.type - robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan) - robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type) - robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0 + robot_traj = robot.get(timestep_range_r, state[robot_type], padding=0.0) + node_state = np.zeros_like(robot_traj[0]) + node_state[:x.shape[1]] = x[-1] + robot_traj_st_t = get_relative_robot_traj(env, state, node_state, robot_traj, node.type, robot_type) # Map map_tuple = None @@ -231,4 +231,4 @@ def get_timesteps_data(env, scene, t, node_type, state, pred_state, scene_graph=scene_graph)) if len(out_timesteps) == 0: return None - return collate(batch), nodes, out_timesteps \ No newline at end of file + return collate(batch), nodes, out_timesteps