simplify dmcontrol wrappers + upgrade to gymnasium==0.29.1

This commit is contained in:
Nicklas Hansen
2024-12-10 15:16:34 -08:00
parent 32fc2bdf93
commit 6117bc427d
11 changed files with 121 additions and 298 deletions

View File

@@ -13,10 +13,9 @@ dependencies:
- pytorch-cuda=12.4
- torchvision=0.15.2
- pip:
- absl-py==2.1.0
- "cython<3"
- dm-control==1.0.8
- dm-control==1.0.16
- glfw==2.7.0
- gymnasium==0.29.1
- ffmpeg==1.4
- imageio==2.34.1
- imageio-ffmpeg==0.4.9
@@ -24,12 +23,9 @@ dependencies:
- hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0
- submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- omegaconf==2.3.0
- moviepy==1.0.3
- mujoco==2.3.1
- mujoco-py==2.1.2.14
- mujoco==3.1.2
- numpy==1.24.4
- tensordict-nightly==2024.11.14
- torchrl-nightly==2024.11.14
@@ -38,10 +34,14 @@ dependencies:
- tqdm==4.66.4
- pandas==2.0.3
- wandb==0.17.4
- wheel==0.38.0
####################
# Gym:
# (unmaintained but required for maniskill2/meta-world)
# - "cython<3"
# - wheel==0.38.0
# - setuptools==65.5.0
# - mujoco==2.3.1
# - mujoco-py==2.1.2.14
# - gym==0.21.0
####################
# ManiSkill2:

View File

@@ -1,10 +1,9 @@
from copy import deepcopy
import warnings
import gym
import gymnasium as gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task):
@@ -70,8 +69,6 @@ def make_env(cfg):
if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box

View File

@@ -1,181 +1,91 @@
from collections import deque, defaultdict
from typing import Any, NamedTuple
import dm_env
from collections import defaultdict, deque
import gymnasium as gym
import numpy as np
import torch
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
import gym
from envs.wrappers.timeout import Timeout
class ExtendedTimeStep(NamedTuple):
step_type: Any
reward: Any
discount: Any
observation: Any
action: Any
def first(self):
return self.step_type == StepType.FIRST
def mid(self):
return self.step_type == StepType.MID
def last(self):
return self.step_type == StepType.LAST
def get_obs_shape(env):
obs_shp = []
for v in env.observation_spec().values():
try:
shp = np.prod(v.shape)
except:
shp = 1
obs_shp.append(shp)
return (int(np.sum(obs_shp)),)
class ActionRepeatWrapper(dm_env.Environment):
def __init__(self, env, num_repeats):
self._env = env
self._num_repeats = num_repeats
def step(self, action):
reward = 0.0
discount = 1.0
for i in range(self._num_repeats):
time_step = self._env.step(action)
reward += (time_step.reward or 0.0) * discount
discount *= time_step.discount
if time_step.last():
break
return time_step._replace(reward=reward, discount=discount)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ActionDTypeWrapper(dm_env.Environment):
def __init__(self, env, dtype):
self._env = env
wrapped_action_spec = env.action_spec()
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
dtype,
wrapped_action_spec.minimum,
wrapped_action_spec.maximum,
'action')
def step(self, action):
action = action.astype(self._env.action_spec().dtype)
return self._env.step(action)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._action_spec
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ExtendedTimeStepWrapper(dm_env.Environment):
def __init__(self, env):
self._env = env
def reset(self):
time_step = self._env.reset()
return self._augment_time_step(time_step)
def step(self, action):
time_step = self._env.step(action)
return self._augment_time_step(time_step, action)
def _augment_time_step(self, time_step, action=None):
if action is None:
action_spec = self.action_spec()
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
return ExtendedTimeStep(observation=time_step.observation,
step_type=time_step.step_type,
action=action,
reward=time_step.reward or 0.0,
discount=time_step.discount or 1.0)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def __getattr__(self, name):
return getattr(self._env, name)
class TimeStepToGymWrapper:
def __init__(self, env, domain, task):
obs_shp = []
for v in env.observation_spec().values():
try:
shp = np.prod(v.shape)
except:
shp = 1
obs_shp.append(shp)
obs_shp = (int(np.sum(obs_shp)),)
act_shp = env.action_spec().shape
self.observation_space = gym.spaces.Box(
low=np.full(
obs_shp,
-np.inf,
dtype=np.float32),
high=np.full(
obs_shp,
np.inf,
dtype=np.float32),
dtype=np.float32,
)
self.action_space = gym.spaces.Box(
low=np.full(act_shp, env.action_spec().minimum),
high=np.full(act_shp, env.action_spec().maximum),
dtype=env.action_spec().dtype)
class DMControlWrapper:
def __init__(self, env, domain):
self.env = env
self.domain = domain
self.task = task
self.max_episode_steps = 500
self.t = 0
self.camera_id = 2 if domain == 'quadruped' else 0
obs_shape = get_obs_shape(env)
action_shape = env.action_spec().shape
self.observation_space = gym.spaces.Box(
low=np.full(obs_shape, -np.inf, dtype=np.float32),
high=np.full(obs_shape, np.inf, dtype=np.float32),
dtype=np.float32)
self.action_space = gym.spaces.Box(
low=np.full(action_shape, env.action_spec().minimum),
high=np.full(action_shape, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.action_spec_dtype = env.action_spec().dtype
@property
def unwrapped(self):
return self.env
@property
def reward_range(self):
return None
@property
def metadata(self):
return None
def _obs_to_array(self, obs):
return np.concatenate([v.flatten() for v in obs.values()])
return torch.from_numpy(
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
def reset(self):
return self._obs_to_array(self.env.reset().observation)
def step(self, action):
reward = 0
action = action.astype(self.action_spec_dtype)
for _ in range(2):
step = self.env.step(action)
reward += step.reward
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
def render(self, width=384, height=384, camera_id=None):
return self.env.physics.render(height, width, camera_id or self.camera_id)
class Pixels(gym.Wrapper):
def __init__(self, env, cfg, num_frames=3, size=64):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(num_frames*3, size, size), dtype=np.uint8)
self._frames = deque([], maxlen=num_frames)
self._size = size
def _get_obs(self, is_reset=False):
frame = self.env.render(width=self._size, height=self._size).transpose(2, 0, 1)
num_frames = self._frames.maxlen if is_reset else 1
for _ in range(num_frames):
self._frames.append(frame)
return torch.from_numpy(np.concatenate(self._frames))
def reset(self):
self.t = 0
return self._obs_to_array(self.env.reset().observation)
def step(self, action):
self.t += 1
time_step = self.env.step(action)
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float)
self.env.reset()
return self._get_obs(is_reset=True)
def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
camera_id = dict(quadruped=2).get(self.domain, camera_id)
return self.env.physics.render(height, width, camera_id)
def step(self, action):
_, reward, done, info = self.env.step(action)
return self._get_obs(), reward, done, info
def make_env(cfg):
@@ -192,9 +102,9 @@ def make_env(cfg):
task,
task_kwargs={'random': cfg.seed},
visualize_reward=False)
env = ActionDTypeWrapper(env, np.float32)
env = ActionRepeatWrapper(env, 2)
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env)
env = TimeStepToGymWrapper(env, domain, task)
env = DMControlWrapper(env, domain)
if cfg.obs == 'rgb':
env = Pixels(env, cfg)
env = Timeout(env, max_episode_steps=500)
return env

View File

@@ -1,6 +1,6 @@
import gym
import gymnasium as gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
from envs.wrappers.timeout import Timeout
import mani_skill2.envs
@@ -74,6 +74,6 @@ def make_env(cfg):
render_camera_cfgs=dict(width=384, height=384),
)
env = ManiSkillWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

View File

@@ -1,6 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.wrappers.timeout import Timeout
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -47,6 +47,6 @@ def make_env(cfg):
assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

View File

@@ -1,6 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
import gymnasium as gym
from envs.wrappers.timeout import Timeout
MYOSUITE_TASKS = {
@@ -53,6 +53,6 @@ def make_env(cfg):
from myosuite.utils import gym as gym_utils
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps
return env

View File

@@ -1,4 +1,4 @@
import gym
import gymnasium as gym
import numpy as np
import torch

View File

@@ -1,38 +0,0 @@
from collections import deque
import gym
import numpy as np
import torch
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Compatible with DMControl environments.
"""
def __init__(self, cfg, env, num_frames=3, render_size=64):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8
)
self._frames = deque([], maxlen=num_frames)
self._render_size = render_size
def _get_obs(self):
frame = self.env.render(
mode='rgb_array', width=self._render_size, height=self._render_size
).transpose(2, 0, 1)
self._frames.append(frame)
return torch.from_numpy(np.concatenate(self._frames))
def reset(self):
self.env.reset()
for _ in range(self._frames.maxlen):
obs = self._get_obs()
return obs
def step(self, action):
_, reward, done, info = self.env.step(action)
return self._get_obs(), reward, done, info

View File

@@ -1,6 +1,6 @@
from collections import defaultdict
import gym
import gymnasium as gym
import numpy as np
import torch
@@ -17,9 +17,10 @@ class TensorWrapper(gym.Wrapper):
return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x):
x = torch.from_numpy(x)
if x.dtype == torch.float64:
x = x.float()
if isinstance(x, np.ndarray):
x = torch.from_numpy(x)
if x.dtype == torch.float64:
x = x.float()
return x
def _obs_to_tensor(self, obs):

View File

@@ -1,72 +0,0 @@
"""
Wrapper for limiting the time steps of an environment.
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
"""
from typing import Optional
import gym
class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
Example:
>>> from gym.envs.classic_control import CartPoleEnv
>>> from gym.wrappers import TimeLimit
>>> env = CartPoleEnv()
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
"""
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(self, action):
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
when truncated (the number of steps elapsed >= max episode steps) or
"TimeLimit.truncated"=False if the environment terminated
"""
observation, reward, done, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
done = True
return observation, reward, done, info
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
**kwargs: The kwargs to reset the environment with
Returns:
The reset environment
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)

View File

@@ -0,0 +1,25 @@
import gymnasium as gym
class Timeout(gym.Wrapper):
"""
Wrapper for enforcing a time limit on the environment.
"""
def __init__(self, env, max_episode_steps):
super().__init__(env)
self._max_episode_steps = max_episode_steps
@property
def max_episode_steps(self):
return self._max_episode_steps
def reset(self, **kwargs):
self._t = 0
return self.env.reset(**kwargs)
def step(self, action):
obs, reward, done, info = self.env.step(action)
self._t += 1
done = done or self._t >= self.max_episode_steps
return obs, reward, done, info