mem maze env ok 1.2
This commit is contained in:
@@ -77,8 +77,8 @@ class CollectDataset:
|
||||
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
||||
elif np.issubdtype(value.dtype, np.uint8):
|
||||
dtype = np.uint8
|
||||
elif np.issubdtype(value.dtype, np.bool_):
|
||||
dtype = np.bool_
|
||||
elif np.issubdtype(value.dtype, np.bool):
|
||||
dtype = np.bool
|
||||
else:
|
||||
raise NotImplementedError(value.dtype)
|
||||
return value.astype(dtype)
|
||||
@@ -168,40 +168,6 @@ class OneHotAction:
|
||||
reference[index] = 1.0
|
||||
return reference
|
||||
|
||||
class OneHotAction2:
|
||||
def __init__(self, env):
|
||||
assert isinstance(env.action_space, gym.spaces.Discrete)
|
||||
self._env = env
|
||||
self._random = np.random.RandomState()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
shape = (self._env.action_space.n,)
|
||||
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
|
||||
space.sample = self._sample_action
|
||||
space.discrete = True
|
||||
return space
|
||||
|
||||
def step(self, action):
|
||||
index = np.argmax(action).astype(int)
|
||||
reference = np.zeros_like(action)
|
||||
reference[index] = 1
|
||||
# if not np.allclose(reference, action):
|
||||
# raise ValueError(f"Invalid one-hot action:\n{action}")
|
||||
return self._env.step(index)
|
||||
|
||||
def reset(self):
|
||||
return self._env.reset()
|
||||
|
||||
def _sample_action(self):
|
||||
actions = self._env.action_space.n
|
||||
index = self._random.randint(0, actions)
|
||||
reference = np.zeros(actions, dtype=np.float32)
|
||||
reference[index] = 1.0
|
||||
return reference
|
||||
|
||||
class RewardObs:
|
||||
def __init__(self, env):
|
||||
|
||||
Reference in New Issue
Block a user