2020-04-06 03:43:49 +02:00
|
|
|
import orjson
|
2020-01-13 19:55:45 +01:00
|
|
|
import numpy as np
|
|
|
|
from itertools import product
|
2020-04-06 03:43:49 +02:00
|
|
|
from .node_type import NodeTypeEnum
|
2020-01-13 19:55:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
class Environment(object):
    """Container for a dataset's scenes, node types and standardization statistics.

    Besides storing the configuration passed to ``__init__``, it provides helpers
    to (un)standardize arrays with per-node-type mean/std statistics, to list all
    ordered node-type pairs, and to compute normalized scene resampling
    probabilities.
    """

    def __init__(self, node_type_list, standardization, scenes=None, attention_radius=None, robot_type=None):
        """Store the environment configuration.

        :param node_type_list: Node type names; ``self.NodeType`` is built from
            them via ``NodeTypeEnum``.
        :param standardization: Nested mapping looked up as
            ``standardization[node_type][entity][dim]['mean']`` (and ``['std']``).
        :param scenes: Optional sequence of scene objects; each must expose
            ``resample_prob`` when ``scenes_resample_prop`` is used.
        :param attention_radius: Stored as-is; not used within this class.
        :param robot_type: Stored as-is; not used within this class.
        """
        self.scenes = scenes
        self.node_type_list = node_type_list
        self.attention_radius = attention_radius
        self.NodeType = NodeTypeEnum(node_type_list)
        self.robot_type = robot_type

        self.standardization = standardization
        # Memo of (mean, std) arrays keyed by (serialized state, node_type).
        self.standardize_param_memo = dict()

        # Lazily filled by the ``scenes_resample_prop`` property.
        self._scenes_resample_prop = None

    def get_edge_types(self):
        """Return every ordered pair of node types (the product NodeType x NodeType).

        :return: List of 2-tuples, including pairs of a type with itself.
        """
        return list(product(self.NodeType, repeat=2))

    def get_standardize_params(self, state, node_type):
        """Return the stacked ``(mean, std)`` statistics for ``state`` and ``node_type``.

        :param state: Mapping of entity name -> iterable of dimension names. Its
            iteration order fixes the order of entries in the returned arrays.
        :param node_type: Key into ``self.standardization``.
        :return: Tuple ``(mean, std)`` of arrays produced by ``np.stack``.

        Results are memoized per ``(state, node_type)``. NOTE: the memoized arrays
        are returned directly (not copied), so callers must not modify them in
        place or later lookups will see the modified values.
        """
        # A dict is unhashable, so its serialized form stands in as the cache key.
        memo_key = (orjson.dumps(state), node_type)
        if memo_key in self.standardize_param_memo:
            return self.standardize_param_memo[memo_key]

        standardize_mean_list = list()
        standardize_std_list = list()
        for entity, dims in state.items():
            for dim in dims:
                standardize_mean_list.append(self.standardization[node_type][entity][dim]['mean'])
                standardize_std_list.append(self.standardization[node_type][entity][dim]['std'])
        standardize_mean = np.stack(standardize_mean_list)
        standardize_std = np.stack(standardize_std_list)

        self.standardize_param_memo[memo_key] = (standardize_mean, standardize_std)
        return standardize_mean, standardize_std

    def _resolve_standardize_params(self, state, node_type, mean, std):
        """Fill in whichever of ``mean``/``std`` is None from the stored statistics."""
        # Only consult the stored statistics when something is actually missing.
        if mean is None or std is None:
            default_mean, default_std = self.get_standardize_params(state, node_type)
            if mean is None:
                mean = default_mean
            if std is None:
                std = default_std
        return mean, std

    def standardize(self, array, state, node_type, mean=None, std=None):
        """Return ``(array - mean) / std``, keeping NaN entries of ``array`` as NaN.

        :param array: Values to standardize; must broadcast against ``mean``/``std``.
        :param state: Forwarded to ``get_standardize_params`` when ``mean`` or
            ``std`` is omitted.
        :param node_type: Forwarded to ``get_standardize_params`` when ``mean`` or
            ``std`` is omitted.
        :param mean: Optional explicit mean overriding the stored statistic.
        :param std: Optional explicit std overriding the stored statistic.
        :return: Standardized array.
        """
        mean, std = self._resolve_standardize_params(state, node_type, mean, std)
        return np.where(np.isnan(array), np.array(np.nan), (array - mean) / std)

    def unstandardize(self, array, state, node_type, mean=None, std=None):
        """Invert ``standardize``: return ``array * std + mean``.

        :param array: Standardized values; must broadcast against ``mean``/``std``.
        :param state: Forwarded to ``get_standardize_params`` when ``mean`` or
            ``std`` is omitted.
        :param node_type: Forwarded to ``get_standardize_params`` when ``mean`` or
            ``std`` is omitted.
        :param mean: Optional explicit mean overriding the stored statistic.
        :param std: Optional explicit std overriding the stored statistic.
        :return: Array in the original (unstandardized) units.
        """
        mean, std = self._resolve_standardize_params(state, node_type, mean, std)
        return array * std + mean

    @property
    def scenes_resample_prop(self):
        """Per-scene resampling probabilities: each ``resample_prob`` divided by their sum.

        The result is computed on first access and cached. The cache is NOT
        refreshed if ``self.scenes`` is changed afterwards.

        :raises ValueError: if there is at least one scene and the
            ``resample_prob`` values do not sum to a finite positive number
            (normalizing would otherwise produce NaN/inf probabilities).
        """
        if self._scenes_resample_prop is None:
            weights = np.array([scene.resample_prob for scene in self.scenes])
            total = np.sum(weights)
            # Refuse to cache meaningless probabilities; an empty scene list
            # still yields an empty array, as before.
            if weights.size > 0 and not (np.isfinite(total) and total > 0):
                raise ValueError(
                    f"Cannot normalize scene resample probabilities: they sum to {total}"
                )
            self._scenes_resample_prop = weights / total
        return self._scenes_resample_prop
|
2020-01-13 19:55:45 +01:00
|
|
|
|