75 lines
3.2 KiB
Python
75 lines
3.2 KiB
Python
|
import numpy as np
|
||
|
from enum import Enum
|
||
|
from itertools import product
|
||
|
|
||
|
|
||
|
class Environment(object):
    """Scenes plus node-type and standardization metadata.

    ``NodeType`` is an ``Enum`` created at runtime from ``node_type_list``.
    Members of such a dynamically created enum cannot be pickled, so
    ``__getstate__`` / ``__setstate__`` convert them to their names and back.
    """

    def __init__(self, node_type_list, standardization, scenes=None, attention_radius=None):
        """Create an environment.

        Args:
            node_type_list: Names of the node types; used to build ``self.NodeType``.
            standardization: Nested mapping accessed as
                ``standardization[node_type_name][entity][dim]['mean' | 'std']``.
            scenes: Optional iterable of scene objects, each exposing a ``nodes``
                iterable whose items carry a ``type`` attribute.
            attention_radius: Optional mapping keyed by ``(NodeType, NodeType)``
                member pairs.
        """
        self.scenes = scenes
        self.node_type_list = node_type_list
        self.attention_radius = attention_radius
        self.NodeType = Enum('NodeType', node_type_list)
        self.standardization = standardization

    def get_edge_types(self):
        """Return every ordered ``(NodeType, NodeType)`` pair, self-pairs included."""
        return list(product(self.NodeType, repeat=2))

    def edge_type_str(self, edge_type):
        """Format an edge type ``(src, dst)`` as ``'SRC-DST'`` using member names."""
        return edge_type[0].name + '-' + edge_type[1].name

    def get_standardize_params(self, state, node_type):
        """Collect the stored mean/std for the dimensions listed in ``state``.

        Args:
            state: Mapping ``entity -> iterable of dimension names``. Iteration
                order fixes the order of the returned entries.
            node_type: A ``NodeType`` member; its ``name`` selects the statistics.

        Returns:
            ``(mean, std)`` as numpy arrays built with ``np.stack``.

        Raises:
            ValueError: from ``np.stack`` when ``state`` lists no dimensions.
            Lookup errors (``KeyError`` for plain dicts) when statistics are
            missing for the node type, entity or dimension.
        """
        mean_list = []
        std_list = []
        for entity, dims in state.items():
            for dim in dims:
                dim_stats = self.standardization[node_type.name][entity][dim]
                mean_list.append(dim_stats['mean'])
                std_list.append(dim_stats['std'])
        return np.stack(mean_list), np.stack(std_list)

    def _resolve_standardize_params(self, state, node_type, mean, std):
        """Fill in whichever of ``mean`` / ``std`` is None from the stored statistics."""
        if mean is None or std is None:
            stored_mean, stored_std = self.get_standardize_params(state, node_type)
            if mean is None:
                mean = stored_mean
            if std is None:
                std = stored_std
        return mean, std

    def standardize(self, array, state, node_type, mean=None, std=None):
        """Return ``(array - mean) / std``; NaN entries of ``array`` stay NaN.

        ``mean`` / ``std`` default to the stored statistics for ``state`` and
        ``node_type``. An explicitly passed value overrides the stored one.
        """
        mean, std = self._resolve_standardize_params(state, node_type, mean, std)
        return np.where(np.isnan(array), np.array(np.nan), (array - mean) / std)

    def unstandardize(self, array, state, node_type, mean=None, std=None):
        """Invert ``standardize``: return ``array * std + mean``.

        ``mean`` / ``std`` default to the stored statistics for ``state`` and
        ``node_type``. An explicitly passed value overrides the stored one.
        """
        mean, std = self._resolve_standardize_params(state, node_type, mean, std)
        return array * std + mean

    # Pickle support: members of an Enum created at runtime cannot be pickled,
    # so node types are stored by name and the enum is rebuilt on load.
    @staticmethod
    def _type_name(node_type):
        """Return the member name of ``node_type``, or ``node_type`` if it is already a str."""
        return node_type if isinstance(node_type, str) else node_type.name

    def __getstate__(self):
        """Return a picklable copy of the instance state.

        ``NodeType`` is dropped (rebuilt from ``node_type_list`` on load) and
        ``attention_radius`` keys are converted to name pairs in the returned
        copy; this instance's own ``NodeType`` and ``attention_radius`` are
        left untouched.

        NOTE: pickle serializes ``self.scenes`` only after this method returns,
        so node ``type`` attributes must be converted to names *in place*.
        After pickling, this instance's nodes hold type names (str). The
        conversion is idempotent, so pickling more than once works.
        """
        scenes = self.scenes if self.scenes is not None else ()
        for scene in scenes:
            for node in scene.nodes:
                node.type = self._type_name(node.type)

        state = self.__dict__.copy()
        state['NodeType'] = None
        if self.attention_radius is not None:
            state['attention_radius'] = {
                (self._type_name(key[0]), self._type_name(key[1])): value
                for key, value in self.attention_radius.items()
            }
        return state

    def __setstate__(self, state):
        """Restore state produced by ``__getstate__`` and rebuild ``NodeType``.

        Node types and ``attention_radius`` keys given as names (or as members
        of another ``NodeType`` enum) are mapped onto the rebuilt enum.
        """
        self.__dict__.update(state)
        self.NodeType = Enum('NodeType', self.node_type_list)

        scenes = self.scenes if self.scenes is not None else ()
        for scene in scenes:
            for node in scene.nodes:
                node.type = getattr(self.NodeType, self._type_name(node.type))

        if self.attention_radius is not None:
            self.attention_radius = {
                (getattr(self.NodeType, self._type_name(key[0])),
                 getattr(self.NodeType, self._type_name(key[1]))): value
                for key, value in self.attention_radius.items()
            }
|
||
|
|