from torch.utils import data
import numpy as np

from .preprocessing import get_node_timestep_data


class EnvironmentDataset(object):
    """Container holding one ``NodeTypeDataset`` per predicted node type.

    A ``NodeTypeDataset`` is built for every entry of ``env.NodeType`` that
    is also found in ``hyperparams['pred_state']``; other node types are
    skipped. Iterating over this object yields those per-type datasets.

    ``kwargs`` must contain ``'min_future_timesteps'`` and is forwarded
    unchanged to each ``NodeTypeDataset``.
    """

    def __init__(self, env, state, pred_state, node_freq_mult,
                 scene_freq_mult, hyperparams, **kwargs):
        self.env = env
        self.state = state
        self.pred_state = pred_state
        self.hyperparams = hyperparams
        self.max_ht = self.hyperparams['maximum_history_length']
        self.max_ft = kwargs['min_future_timesteps']
        # NOTE(review): always starts as False here, even if an ``augment``
        # entry is forwarded to the children through ``kwargs`` -- confirm
        # that callers only toggle augmentation through the property.
        self._augment = False
        self.node_type_datasets = [
            NodeTypeDataset(env, node_type, state, pred_state,
                            node_freq_mult, scene_freq_mult, hyperparams,
                            **kwargs)
            for node_type in env.NodeType
            if node_type in hyperparams['pred_state']
        ]

    @property
    def augment(self):
        """Augmentation flag last set through this property."""
        return self._augment

    @augment.setter
    def augment(self, value):
        # Store the flag locally, then push it down to every child dataset.
        self._augment = value
        for child in self.node_type_datasets:
            child.augment = value

    def __iter__(self):
        """Iterate over the per-node-type datasets."""
        return iter(self.node_type_datasets)


class NodeTypeDataset(data.Dataset):
    """Map-style dataset over ``(scene, timestep, node)`` entries of one type.

    The flat index is built once at construction time by ``index_env``;
    ``__getitem__`` resolves an entry and hands it to
    ``get_node_timestep_data``.

    ``kwargs`` must contain ``'min_future_timesteps'``; all of ``kwargs`` is
    also forwarded to ``scene.present_nodes`` when the index is built.
    """

    def __init__(self, env, node_type, state, pred_state, node_freq_mult,
                 scene_freq_mult, hyperparams, augment=False, **kwargs):
        self.env = env
        self.state = state
        self.pred_state = pred_state
        self.hyperparams = hyperparams
        self.max_ht = self.hyperparams['maximum_history_length']
        self.max_ft = kwargs['min_future_timesteps']

        self.augment = augment

        # ``index_env`` reads ``self.env`` and ``self.node_type``, so both
        # have to be assigned before the index is built.
        self.node_type = node_type
        self.index = self.index_env(node_freq_mult, scene_freq_mult, **kwargs)
        self.len = len(self.index)
        # Keep only edge types whose first element is this node type.
        # NOTE(review): the comparison is by identity (``is``), not equality;
        # confirm edge types are built from the same NodeType objects.
        self.edge_types = [
            candidate for candidate in env.get_edge_types()
            if candidate[0] is node_type
        ]

    def index_env(self, node_freq_mult, scene_freq_mult, **kwargs):
        """Build the flat list of ``(scene, t, node)`` entries.

        For every scene in ``self.env.scenes`` the nodes of this dataset's
        type are queried via ``scene.present_nodes`` over timesteps
        ``0 .. scene.timesteps - 1``. Each entry is repeated
        ``scene.frequency_multiplier`` times when ``scene_freq_mult`` is
        truthy and ``node.frequency_multiplier`` times when
        ``node_freq_mult`` is truthy.

        Returns:
            list of ``(scene, t, node)`` tuples, including the repeats.
        """
        entries = []
        for scene in self.env.scenes:
            timesteps = np.arange(0, scene.timesteps)
            nodes_by_timestep = scene.present_nodes(
                timesteps, type=self.node_type, **kwargs)
            for t, nodes in nodes_by_timestep.items():
                for node in nodes:
                    scene_repeats = (scene.frequency_multiplier
                                     if scene_freq_mult else 1)
                    node_repeats = (node.frequency_multiplier
                                    if node_freq_mult else 1)
                    entries.extend(
                        [(scene, t, node)] * scene_repeats * node_repeats)
        return entries

    def __len__(self):
        """Number of index entries (repeats included)."""
        return self.len

    def __getitem__(self, i):
        """Return the ``get_node_timestep_data`` result for index entry ``i``.

        When ``self.augment`` is truthy the scene is replaced by
        ``scene.augment()`` and the node is looked up again, by id, in that
        returned scene before the data is extracted.
        """
        scene, t, node = self.index[i]

        if self.augment:
            scene = scene.augment()
            node = scene.get_node_by_id(node.id)

        return get_node_timestep_data(
            self.env, scene, t, node,
            self.state, self.pred_state, self.edge_types,
            self.max_ht, self.max_ft, self.hyperparams,
        )