env v0.12

This commit is contained in:
张德祥
2023-06-13 21:39:04 +08:00
parent 5038a91aad
commit b9120a7440
4 changed files with 24 additions and 8 deletions

View File

@@ -47,17 +47,32 @@ class MZGymWrapper:
else:
return {self._act_key: self._env.action_space}
@property
def observation_space(self):
img_shape = self._size + ((1,) if self._gray else (3,))
return gym.spaces.Dict(
{
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
}
)
@property
def action_space(self):
space = self._env.action_space
space.discrete = True
return space
def step(self, action):
if not self._act_is_dict:
action = action[self._act_key]
# if not self._act_is_dict:
# action = action[self._act_key]
obs, reward, done, info = self._env.step(action)
if not self._obs_is_dict:
obs = {self._obs_key: obs}
obs['reward'] = float(reward)
# obs['reward'] = float(reward)
obs['is_first'] = False
obs['is_last'] = done
obs['is_terminal'] = info.get('is_terminal', done)
return obs
return obs, reward, done, info
def reset(self):
obs = self._env.reset()

View File

@@ -77,8 +77,8 @@ class CollectDataset:
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
elif np.issubdtype(value.dtype, np.uint8):
dtype = np.uint8
elif np.issubdtype(value.dtype, np.bool):
dtype = np.bool
elif np.issubdtype(value.dtype, np.bool_):
dtype = np.bool_
else:
raise NotImplementedError(value.dtype)
return value.astype(dtype)
@@ -96,6 +96,7 @@ class TimeLimit:
def step(self, action):
assert self._step is not None, "Must reset environment."
obs, reward, done, info = self._env.step(action)
# teets = self._env.step(action)
self._step += 1
if self._step >= self._duration:
done = True