applied formatter to envs
This commit is contained in:
104
envs/dmc.py
104
envs/dmc.py
@@ -3,62 +3,60 @@ import numpy as np
|
||||
|
||||
|
||||
class DeepMindControl:
|
||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
||||
domain, task = name.split("_", 1)
|
||||
if domain == "cup": # Only domain with multiple words.
|
||||
domain = "ball_in_cup"
|
||||
if isinstance(domain, str):
|
||||
from dm_control import suite
|
||||
|
||||
def __init__(self, name, action_repeat=1, size=(64, 64), camera=None):
|
||||
domain, task = name.split('_', 1)
|
||||
if domain == 'cup': # Only domain with multiple words.
|
||||
domain = 'ball_in_cup'
|
||||
if isinstance(domain, str):
|
||||
from dm_control import suite
|
||||
self._env = suite.load(domain, task)
|
||||
else:
|
||||
assert task is None
|
||||
self._env = domain()
|
||||
self._action_repeat = action_repeat
|
||||
self._size = size
|
||||
if camera is None:
|
||||
camera = dict(quadruped=2).get(domain, 0)
|
||||
self._camera = camera
|
||||
self._env = suite.load(domain, task)
|
||||
else:
|
||||
assert task is None
|
||||
self._env = domain()
|
||||
self._action_repeat = action_repeat
|
||||
self._size = size
|
||||
if camera is None:
|
||||
camera = dict(quadruped=2).get(domain, 0)
|
||||
self._camera = camera
|
||||
|
||||
@property
|
||||
def observation_space(self):
|
||||
spaces = {}
|
||||
for key, value in self._env.observation_spec().items():
|
||||
spaces[key] = gym.spaces.Box(
|
||||
-np.inf, np.inf, value.shape, dtype=np.float32)
|
||||
spaces['image'] = gym.spaces.Box(
|
||||
0, 255, self._size + (3,), dtype=np.uint8)
|
||||
return gym.spaces.Dict(spaces)
|
||||
@property
|
||||
def observation_space(self):
|
||||
spaces = {}
|
||||
for key, value in self._env.observation_spec().items():
|
||||
spaces[key] = gym.spaces.Box(-np.inf, np.inf, value.shape, dtype=np.float32)
|
||||
spaces["image"] = gym.spaces.Box(0, 255, self._size + (3,), dtype=np.uint8)
|
||||
return gym.spaces.Dict(spaces)
|
||||
|
||||
@property
|
||||
def action_space(self):
|
||||
spec = self._env.action_spec()
|
||||
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
|
||||
@property
|
||||
def action_space(self):
|
||||
spec = self._env.action_spec()
|
||||
return gym.spaces.Box(spec.minimum, spec.maximum, dtype=np.float32)
|
||||
|
||||
def step(self, action):
|
||||
assert np.isfinite(action).all(), action
|
||||
reward = 0
|
||||
for _ in range(self._action_repeat):
|
||||
time_step = self._env.step(action)
|
||||
reward += time_step.reward or 0
|
||||
if time_step.last():
|
||||
break
|
||||
obs = dict(time_step.observation)
|
||||
obs['image'] = self.render()
|
||||
# There is no terminal state in DMC
|
||||
obs['is_terminal'] = False
|
||||
done = time_step.last()
|
||||
info = {'discount': np.array(time_step.discount, np.float32)}
|
||||
return obs, reward, done, info
|
||||
def step(self, action):
|
||||
assert np.isfinite(action).all(), action
|
||||
reward = 0
|
||||
for _ in range(self._action_repeat):
|
||||
time_step = self._env.step(action)
|
||||
reward += time_step.reward or 0
|
||||
if time_step.last():
|
||||
break
|
||||
obs = dict(time_step.observation)
|
||||
obs["image"] = self.render()
|
||||
# There is no terminal state in DMC
|
||||
obs["is_terminal"] = False
|
||||
done = time_step.last()
|
||||
info = {"discount": np.array(time_step.discount, np.float32)}
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
time_step = self._env.reset()
|
||||
obs = dict(time_step.observation)
|
||||
obs['image'] = self.render()
|
||||
obs['is_terminal'] = False
|
||||
return obs
|
||||
def reset(self):
|
||||
time_step = self._env.reset()
|
||||
obs = dict(time_step.observation)
|
||||
obs["image"] = self.render()
|
||||
obs["is_terminal"] = False
|
||||
return obs
|
||||
|
||||
def render(self, *args, **kwargs):
|
||||
if kwargs.get('mode', 'rgb_array') != 'rgb_array':
|
||||
raise ValueError("Only render mode 'rgb_array' is supported.")
|
||||
return self._env.physics.render(*self._size, camera_id=self._camera)
|
||||
def render(self, *args, **kwargs):
|
||||
if kwargs.get("mode", "rgb_array") != "rgb_array":
|
||||
raise ValueError("Only render mode 'rgb_array' is supported.")
|
||||
return self._env.physics.render(*self._size, camera_id=self._camera)
|
||||
|
||||
Reference in New Issue
Block a user