added state input capability
This commit is contained in:
@@ -24,7 +24,11 @@ class DeepMindControl:
|
||||
def observation_space(self):
|
||||
spaces = {}
|
||||
for key, value in self._env.observation_spec().items():
|
||||
spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
|
||||
if len(value.shape) == 0:
|
||||
shape = (1,)
|
||||
else:
|
||||
shape = value.shape
|
||||
spaces[key] = gym.spaces.Box(-np.inf, np.inf, shape, dtype=np.float32)
|
||||
spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
|
||||
return gym.spaces.Dict(spaces)
|
||||
|
||||
@@ -42,6 +46,7 @@ class DeepMindControl:
|
||||
if time_step.last():
|
||||
break
|
||||
obs = dict(time_step.observation)
|
||||
obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
|
||||
obs["image"] = self.render()
|
||||
# There is no terminal state in DMC
|
||||
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
|
||||
@@ -53,6 +58,7 @@ class DeepMindControl:
|
||||
def reset(self):
|
||||
time_step = self._env.reset()
|
||||
obs = dict(time_step.observation)
|
||||
obs = {key: [val] if len(val.shape) == 0 else val for key, val in obs.items()}
|
||||
obs["image"] = self.render()
|
||||
obs["is_terminal"] = False if time_step.first() else time_step.discount == 0
|
||||
obs["is_first"] = time_step.first()
|
||||
|
||||
Reference in New Issue
Block a user