simplify dmcontrol wrappers + upgrade to gymnasium==0.29.1

This commit is contained in:
Nicklas Hansen
2024-12-10 15:16:34 -08:00
parent 32fc2bdf93
commit 6117bc427d
11 changed files with 121 additions and 298 deletions

View File

@@ -13,10 +13,9 @@ dependencies:
- pytorch-cuda=12.4 - pytorch-cuda=12.4
- torchvision=0.15.2 - torchvision=0.15.2
- pip: - pip:
- absl-py==2.1.0 - dm-control==1.0.16
- "cython<3"
- dm-control==1.0.8
- glfw==2.7.0 - glfw==2.7.0
- gymnasium==0.29.1
- ffmpeg==1.4 - ffmpeg==1.4
- imageio==2.34.1 - imageio==2.34.1
- imageio-ffmpeg==0.4.9 - imageio-ffmpeg==0.4.9
@@ -24,12 +23,9 @@ dependencies:
- hydra-core==1.3.2 - hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0 - hydra-submitit-launcher==1.2.0
- submitit==1.5.1 - submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- omegaconf==2.3.0 - omegaconf==2.3.0
- moviepy==1.0.3 - moviepy==1.0.3
- mujoco==2.3.1 - mujoco==3.1.2
- mujoco-py==2.1.2.14
- numpy==1.24.4 - numpy==1.24.4
- tensordict-nightly==2024.11.14 - tensordict-nightly==2024.11.14
- torchrl-nightly==2024.11.14 - torchrl-nightly==2024.11.14
@@ -38,10 +34,14 @@ dependencies:
- tqdm==4.66.4 - tqdm==4.66.4
- pandas==2.0.3 - pandas==2.0.3
- wandb==0.17.4 - wandb==0.17.4
- wheel==0.38.0
#################### ####################
# Gym: # Gym:
# (unmaintained but required for maniskill2/meta-world) # (unmaintained but required for maniskill2/meta-world)
# - "cython<3"
# - wheel==0.38.0
# - setuptools==65.5.0
# - mujoco==2.3.1
# - mujoco-py==2.1.2.14
# - gym==0.21.0 # - gym==0.21.0
#################### ####################
# ManiSkill2: # ManiSkill2:

View File

@@ -1,10 +1,9 @@
from copy import deepcopy from copy import deepcopy
import warnings import warnings
import gym import gymnasium as gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task): def missing_dependencies(task):
@@ -70,8 +69,6 @@ def make_env(cfg):
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env) env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box except: # Box

View File

@@ -1,124 +1,18 @@
from collections import deque, defaultdict from collections import defaultdict, deque
from typing import Any, NamedTuple
import dm_env import gymnasium as gym
import numpy as np import numpy as np
import torch
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
from dm_control import suite from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs from envs.wrappers.timeout import Timeout
import gym
class ExtendedTimeStep(NamedTuple): def get_obs_shape(env):
step_type: Any
reward: Any
discount: Any
observation: Any
action: Any
def first(self):
return self.step_type == StepType.FIRST
def mid(self):
return self.step_type == StepType.MID
def last(self):
return self.step_type == StepType.LAST
class ActionRepeatWrapper(dm_env.Environment):
def __init__(self, env, num_repeats):
self._env = env
self._num_repeats = num_repeats
def step(self, action):
reward = 0.0
discount = 1.0
for i in range(self._num_repeats):
time_step = self._env.step(action)
reward += (time_step.reward or 0.0) * discount
discount *= time_step.discount
if time_step.last():
break
return time_step._replace(reward=reward, discount=discount)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ActionDTypeWrapper(dm_env.Environment):
def __init__(self, env, dtype):
self._env = env
wrapped_action_spec = env.action_spec()
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
dtype,
wrapped_action_spec.minimum,
wrapped_action_spec.maximum,
'action')
def step(self, action):
action = action.astype(self._env.action_spec().dtype)
return self._env.step(action)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._action_spec
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ExtendedTimeStepWrapper(dm_env.Environment):
def __init__(self, env):
self._env = env
def reset(self):
time_step = self._env.reset()
return self._augment_time_step(time_step)
def step(self, action):
time_step = self._env.step(action)
return self._augment_time_step(time_step, action)
def _augment_time_step(self, time_step, action=None):
if action is None:
action_spec = self.action_spec()
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
return ExtendedTimeStep(observation=time_step.observation,
step_type=time_step.step_type,
action=action,
reward=time_step.reward or 0.0,
discount=time_step.discount or 1.0)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def __getattr__(self, name):
return getattr(self._env, name)
class TimeStepToGymWrapper:
def __init__(self, env, domain, task):
obs_shp = [] obs_shp = []
for v in env.observation_spec().values(): for v in env.observation_spec().values():
try: try:
@@ -126,56 +20,72 @@ class TimeStepToGymWrapper:
except: except:
shp = 1 shp = 1
obs_shp.append(shp) obs_shp.append(shp)
obs_shp = (int(np.sum(obs_shp)),) return (int(np.sum(obs_shp)),)
act_shp = env.action_spec().shape
self.observation_space = gym.spaces.Box(
low=np.full( class DMControlWrapper:
obs_shp, def __init__(self, env, domain):
-np.inf,
dtype=np.float32),
high=np.full(
obs_shp,
np.inf,
dtype=np.float32),
dtype=np.float32,
)
self.action_space = gym.spaces.Box(
low=np.full(act_shp, env.action_spec().minimum),
high=np.full(act_shp, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.env = env self.env = env
self.domain = domain self.camera_id = 2 if domain == 'quadruped' else 0
self.task = task obs_shape = get_obs_shape(env)
self.max_episode_steps = 500 action_shape = env.action_spec().shape
self.t = 0 self.observation_space = gym.spaces.Box(
low=np.full(obs_shape, -np.inf, dtype=np.float32),
high=np.full(obs_shape, np.inf, dtype=np.float32),
dtype=np.float32)
self.action_space = gym.spaces.Box(
low=np.full(action_shape, env.action_spec().minimum),
high=np.full(action_shape, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.action_spec_dtype = env.action_spec().dtype
@property @property
def unwrapped(self): def unwrapped(self):
return self.env return self.env
@property
def reward_range(self):
return None
@property
def metadata(self):
return None
def _obs_to_array(self, obs): def _obs_to_array(self, obs):
return np.concatenate([v.flatten() for v in obs.values()]) return torch.from_numpy(
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
def reset(self): def reset(self):
self.t = 0
return self._obs_to_array(self.env.reset().observation) return self._obs_to_array(self.env.reset().observation)
def step(self, action): def step(self, action):
self.t += 1 reward = 0
time_step = self.env.step(action) action = action.astype(self.action_spec_dtype)
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float) for _ in range(2):
step = self.env.step(action)
reward += step.reward
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
def render(self, mode='rgb_array', width=384, height=384, camera_id=0): def render(self, width=384, height=384, camera_id=None):
camera_id = dict(quadruped=2).get(self.domain, camera_id) return self.env.physics.render(height, width, camera_id or self.camera_id)
return self.env.physics.render(height, width, camera_id)
class Pixels(gym.Wrapper):
    """
    Replace the wrapped environment's observations with stacked rendered frames.

    Each observation is the concatenation (along the first axis) of the
    ``num_frames`` most recent frames rendered from the wrapped environment at
    ``size`` x ``size`` resolution, returned as a ``torch`` tensor.
    """

    def __init__(self, env, cfg, num_frames=3, size=64):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        # Declared as channel-first uint8 images, 3 channels per stacked frame.
        stacked_shape = (num_frames*3, size, size)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=stacked_shape, dtype=np.uint8)
        # Bounded deque: appending a new frame evicts the oldest one.
        self._frames = deque([], maxlen=num_frames)
        self._size = size

    def _get_obs(self, is_reset=False):
        """
        Render one frame, push it onto the frame stack and return the stack.

        On reset the same frame is pushed ``maxlen`` times so that the stack
        is completely filled with the initial frame.
        """
        rendered = self.env.render(width=self._size, height=self._size)
        # Move the last axis first (presumably HWC -> CHW; depends on render()).
        frame = rendered.transpose(2, 0, 1)
        copies = self._frames.maxlen if is_reset else 1
        self._frames.extend([frame] * copies)
        stacked = np.concatenate(self._frames)
        return torch.from_numpy(stacked)

    def reset(self):
        """Reset the wrapped environment and return a freshly filled frame stack."""
        self.env.reset()
        return self._get_obs(is_reset=True)

    def step(self, action):
        """Step the wrapped environment, swapping its observation for the frame stack."""
        # The wrapped env's own observation is discarded in favour of pixels.
        _, reward, done, info = self.env.step(action)
        pixels = self._get_obs()
        return pixels, reward, done, info
def make_env(cfg): def make_env(cfg):
@@ -192,9 +102,9 @@ def make_env(cfg):
task, task,
task_kwargs={'random': cfg.seed}, task_kwargs={'random': cfg.seed},
visualize_reward=False) visualize_reward=False)
env = ActionDTypeWrapper(env, np.float32)
env = ActionRepeatWrapper(env, 2)
env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env) env = DMControlWrapper(env, domain)
env = TimeStepToGymWrapper(env, domain, task) if cfg.obs == 'rgb':
env = Pixels(env, cfg)
env = Timeout(env, max_episode_steps=500)
return env return env

View File

@@ -1,6 +1,6 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
import mani_skill2.envs import mani_skill2.envs
@@ -74,6 +74,6 @@ def make_env(cfg):
render_camera_cfgs=dict(width=384, height=384), render_camera_cfgs=dict(width=384, height=384),
) )
env = ManiSkillWrapper(env, cfg) env = ManiSkillWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import gym import gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -47,6 +47,6 @@ def make_env(cfg):
assert cfg.obs == 'state', 'This task only supports state observations.' assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg) env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import gym import gymnasium as gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
MYOSUITE_TASKS = { MYOSUITE_TASKS = {
@@ -53,6 +53,6 @@ def make_env(cfg):
from myosuite.utils import gym as gym_utils from myosuite.utils import gym as gym_utils
env = gym_utils.make(MYOSUITE_TASKS[cfg.task]) env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg) env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

View File

@@ -1,4 +1,4 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch

View File

@@ -1,38 +0,0 @@
from collections import deque
import gym
import numpy as np
import torch
class PixelWrapper(gym.Wrapper):
    """
    Wrapper for pixel observations. Compatible with DMControl environments.

    Observations are the ``num_frames`` most recently rendered frames,
    concatenated along the first axis and returned as a ``torch`` tensor.
    """

    def __init__(self, cfg, env, num_frames=3, render_size=64):
        super().__init__(env)
        self.cfg = cfg
        self.env = env
        # Declared as channel-first uint8 images, 3 channels per stacked frame.
        stacked_shape = (num_frames*3, render_size, render_size)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=stacked_shape, dtype=np.uint8
        )
        # Bounded deque: appending a new frame evicts the oldest one.
        self._frames = deque([], maxlen=num_frames)
        self._render_size = render_size

    def _get_obs(self):
        """Render one frame, push it onto the stack and return the stacked frames."""
        side = self._render_size
        rendered = self.env.render(mode='rgb_array', width=side, height=side)
        # Move the last axis first (presumably HWC -> CHW; depends on render()).
        self._frames.append(rendered.transpose(2, 0, 1))
        stacked = np.concatenate(self._frames)
        return torch.from_numpy(stacked)

    def reset(self):
        """Reset the wrapped environment and fill the frame stack by rendering repeatedly."""
        self.env.reset()
        # One render per stack slot; the observation from the last pass is returned.
        for _ in range(self._frames.maxlen):
            stacked = self._get_obs()
        return stacked

    def step(self, action):
        """Step the wrapped environment, swapping its observation for the frame stack."""
        # The wrapped env's own observation is discarded in favour of pixels.
        _, reward, done, info = self.env.step(action)
        pixels = self._get_obs()
        return pixels, reward, done, info

View File

@@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
import gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
@@ -17,6 +17,7 @@ class TensorWrapper(gym.Wrapper):
return torch.from_numpy(self.action_space.sample().astype(np.float32)) return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x): def _try_f32_tensor(self, x):
if isinstance(x, np.ndarray):
x = torch.from_numpy(x) x = torch.from_numpy(x)
if x.dtype == torch.float64: if x.dtype == torch.float64:
x = x.float() x = x.float()

View File

@@ -1,72 +0,0 @@
"""
Wrapper for limiting the time steps of an environment.
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
"""
from typing import Optional
import gym
class TimeLimit(gym.Wrapper):
    """Issue a `done` signal once a maximum number of timesteps has elapsed.

    It is often **very** important to tell apart `done` signals produced by
    this wrapper (truncations) from those produced by the underlying
    environment (terminations). The ``info`` dict returned alongside the
    `done` signal carries that distinction: the episode was truncated by the
    time limit if and only if the key `"TimeLimit.truncated"` is present in
    ``info`` and its value is ``True``.

    Example:
        >>> from gym.envs.classic_control import CartPoleEnv
        >>> from gym.wrappers import TimeLimit
        >>> env = CartPoleEnv()
        >>> env = TimeLimit(env, max_episode_steps=1000)
    """

    def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
        """Wrap ``env`` and record the step budget after which episodes end.

        Args:
            env: The environment to apply the wrapper to.
            max_episode_steps: Optional maximum number of steps per episode
                (if ``None``, ``env.spec.max_episode_steps`` is used).
        """
        super().__init__(env)
        spec = self.env.spec
        if spec is not None:
            # Fall back to the spec's limit, then keep the spec in sync with
            # whichever limit is actually being enforced.
            if max_episode_steps is None:
                max_episode_steps = spec.max_episode_steps
            spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        # Stays None until reset() is called for the first time.
        self._elapsed_steps = None

    def step(self, action):
        """Step the environment and end the episode once the step budget is used up.

        Args:
            action: The environment step action.

        Returns:
            The environment step ``(observation, reward, done, info)``. Once
            the number of elapsed steps is >= the maximum, ``done`` is ``True``
            and ``info["TimeLimit.truncated"]`` is ``True`` if the episode was
            truncated, or ``False`` if the environment terminated by itself.
        """
        observation, reward, done, info = self.env.step(action)
        self._elapsed_steps += 1
        reached_limit = self._elapsed_steps >= self._max_episode_steps
        if not reached_limit:
            return observation, reward, done, info
        # The environment may already have set the key itself; in that case
        # its value is kept rather than overwritten.
        info["TimeLimit.truncated"] = not done or info.get("TimeLimit.truncated", False)
        return observation, reward, True, info

    def reset(self, **kwargs):
        """Reset the environment with ``**kwargs`` and zero the elapsed-step counter.

        Args:
            **kwargs: The kwargs to reset the environment with.

        Returns:
            The reset environment's return value.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

View File

@@ -0,0 +1,25 @@
import gymnasium as gym
class Timeout(gym.Wrapper):
    """
    Wrapper for enforcing a time limit on the environment.

    Counts the calls to ``step`` since the last ``reset`` and reports
    ``done`` once that count reaches ``max_episode_steps``. The wrapped
    environment is expected to return the 4-tuple
    ``(obs, reward, done, info)`` from ``step``.
    """

    def __init__(self, env, max_episode_steps):
        """
        Args:
            env: Environment to wrap.
            max_episode_steps: Number of steps after which ``done`` is forced.
        """
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        # Initialise the counter here so that step() has a defined value to
        # increment even if it is called before the first reset().
        self._t = 0

    @property
    def max_episode_steps(self):
        """Number of steps after which an episode is reported as done."""
        return self._max_episode_steps

    @max_episode_steps.setter
    def max_episode_steps(self, value):
        # Callers assign `env.max_episode_steps = env._max_episode_steps`
        # right after wrapping; without a setter that assignment raises
        # AttributeError on a read-only property.
        self._max_episode_steps = value

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped environment (kwargs are forwarded)."""
        self._t = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """
        Step the wrapped environment and apply the time limit.

        Returns:
            ``(obs, reward, done, info)`` where ``done`` is truthy if the
            wrapped environment reported done or the step budget is used up.
        """
        obs, reward, done, info = self.env.step(action)
        self._t += 1
        done = done or self._t >= self.max_episode_steps
        return obs, reward, done, info