env v0.12
This commit is contained in:
@@ -47,17 +47,32 @@ class MZGymWrapper:
|
||||
else:
|
||||
return {self._act_key: self._env.action_space}
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
img_shape = self._size + ((1,) if self._gray else (3,))
|
||||
return gym.spaces.Dict(
|
||||
{
|
||||
"image": gym.spaces.Box(0, 255, img_shape, np.uint8),
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
space = self._env.action_space
|
||||
space.discrete = True
|
||||
return space
|
||||
|
||||
def step(self, action):
|
||||
if not self._act_is_dict:
|
||||
action = action[self._act_key]
|
||||
# if not self._act_is_dict:
|
||||
# action = action[self._act_key]
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
if not self._obs_is_dict:
|
||||
obs = {self._obs_key: obs}
|
||||
obs['reward'] = float(reward)
|
||||
# obs['reward'] = float(reward)
|
||||
obs['is_first'] = False
|
||||
obs['is_last'] = done
|
||||
obs['is_terminal'] = info.get('is_terminal', done)
|
||||
return obs
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
obs = self._env.reset()
|
||||
|
||||
@@ -77,8 +77,8 @@ class CollectDataset:
|
||||
dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision]
|
||||
elif np.issubdtype(value.dtype, np.uint8):
|
||||
dtype = np.uint8
|
||||
elif np.issubdtype(value.dtype, np.bool):
|
||||
dtype = np.bool
|
||||
elif np.issubdtype(value.dtype, np.bool_):
|
||||
dtype = np.bool_
|
||||
else:
|
||||
raise NotImplementedError(value.dtype)
|
||||
return value.astype(dtype)
|
||||
@@ -96,6 +96,7 @@ class TimeLimit:
|
||||
def step(self, action):
|
||||
assert self._step is not None, "Must reset environment."
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
# teets = self._env.step(action)
|
||||
self._step += 1
|
||||
if self._step >= self._duration:
|
||||
done = True
|
||||
|
||||
Reference in New Issue
Block a user