mem maze env ok 1

This commit is contained in:
张德祥
2023-06-17 23:29:53 +08:00
parent 1cf0149c10
commit ea446adaf4
4 changed files with 48 additions and 588 deletions

View File

@@ -96,7 +96,6 @@ class TimeLimit:
def step(self, action):
assert self._step is not None, "Must reset environment."
obs, reward, done, info = self._env.step(action)
# teets = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True
@@ -151,6 +150,41 @@ class OneHotAction:
space.discrete = True
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
reference[index] = 1
if not np.allclose(reference, action):
raise ValueError(f"Invalid one-hot action:\n{action}")
return self._env.step(index)
def reset(self):
return self._env.reset()
def _sample_action(self):
actions = self._env.action_space.n
index = self._random.randint(0, actions)
reference = np.zeros(actions, dtype=np.float32)
reference[index] = 1.0
return reference
class OneHotAction2:
def __init__(self, env):
assert isinstance(env.action_space, gym.spaces.Discrete)
self._env = env
self._random = np.random.RandomState()
def __getattr__(self, name):
return getattr(self._env, name)
@property
def action_space(self):
shape = (self._env.action_space.n,)
space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32)
space.sample = self._sample_action
space.discrete = True
return space
def step(self, action):
index = np.argmax(action).astype(int)
reference = np.zeros_like(action)
@@ -169,7 +203,6 @@ class OneHotAction:
reference[index] = 1.0
return reference
class RewardObs:
def __init__(self, env):
self._env = env