simplify dmcontrol wrappers + upgrade to gymnasium==0.29.1
This commit is contained in:
@@ -13,10 +13,9 @@ dependencies:
|
||||
- pytorch-cuda=12.4
|
||||
- torchvision=0.15.2
|
||||
- pip:
|
||||
- absl-py==2.1.0
|
||||
- "cython<3"
|
||||
- dm-control==1.0.8
|
||||
- dm-control==1.0.16
|
||||
- glfw==2.7.0
|
||||
- gymnasium==0.29.1
|
||||
- ffmpeg==1.4
|
||||
- imageio==2.34.1
|
||||
- imageio-ffmpeg==0.4.9
|
||||
@@ -24,12 +23,9 @@ dependencies:
|
||||
- hydra-core==1.3.2
|
||||
- hydra-submitit-launcher==1.2.0
|
||||
- submitit==1.5.1
|
||||
- setuptools==65.5.0
|
||||
- patchelf==0.17.2.1
|
||||
- omegaconf==2.3.0
|
||||
- moviepy==1.0.3
|
||||
- mujoco==2.3.1
|
||||
- mujoco-py==2.1.2.14
|
||||
- mujoco==3.1.2
|
||||
- numpy==1.24.4
|
||||
- tensordict-nightly==2024.11.14
|
||||
- torchrl-nightly==2024.11.14
|
||||
@@ -38,10 +34,14 @@ dependencies:
|
||||
- tqdm==4.66.4
|
||||
- pandas==2.0.3
|
||||
- wandb==0.17.4
|
||||
- wheel==0.38.0
|
||||
####################
|
||||
# Gym:
|
||||
# (unmaintained but required for maniskill2/meta-world)
|
||||
# - "cython<3"
|
||||
# - wheel==0.38.0
|
||||
# - setuptools==65.5.0
|
||||
# - mujoco==2.3.1
|
||||
# - mujoco-py==2.1.2.14
|
||||
# - gym==0.21.0
|
||||
####################
|
||||
# ManiSkill2:
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from copy import deepcopy
|
||||
import warnings
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
|
||||
from envs.wrappers.multitask import MultitaskWrapper
|
||||
from envs.wrappers.pixels import PixelWrapper
|
||||
from envs.wrappers.tensor import TensorWrapper
|
||||
|
||||
def missing_dependencies(task):
|
||||
@@ -70,8 +69,6 @@ def make_env(cfg):
|
||||
if env is None:
|
||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||
env = TensorWrapper(env)
|
||||
if cfg.get('obs', 'state') == 'rgb':
|
||||
env = PixelWrapper(cfg, env)
|
||||
try: # Dict
|
||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||
except: # Box
|
||||
|
||||
@@ -1,181 +1,91 @@
|
||||
from collections import deque, defaultdict
|
||||
from typing import Any, NamedTuple
|
||||
import dm_env
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
||||
from dm_control import suite
|
||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||
from dm_control.suite.wrappers import action_scale
|
||||
from dm_env import StepType, specs
|
||||
import gym
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
|
||||
class ExtendedTimeStep(NamedTuple):
|
||||
step_type: Any
|
||||
reward: Any
|
||||
discount: Any
|
||||
observation: Any
|
||||
action: Any
|
||||
|
||||
def first(self):
|
||||
return self.step_type == StepType.FIRST
|
||||
|
||||
def mid(self):
|
||||
return self.step_type == StepType.MID
|
||||
|
||||
def last(self):
|
||||
return self.step_type == StepType.LAST
|
||||
def get_obs_shape(env):
    """Return the flattened observation shape of ``env`` as a 1-tuple.

    Iterates over the values of ``env.observation_spec()`` and sums the
    number of elements of each entry (the product of its ``shape``).
    Entries whose size cannot be computed from ``shape`` count as a
    single element.

    Args:
        env: Object exposing ``observation_spec()`` that returns a mapping
            of observation specs.

    Returns:
        ``(total,)`` where ``total`` is the summed element count as an int.
    """
    sizes = []
    for spec in env.observation_spec().values():
        try:
            size = np.prod(spec.shape)
        except (AttributeError, TypeError, ValueError):
            # No usable ``shape`` on this entry: treat it as one scalar.
            # Narrowed from a bare ``except`` so KeyboardInterrupt and
            # SystemExit are no longer swallowed here.
            size = 1
        sizes.append(size)
    return (int(np.sum(sizes)),)
|
||||
|
||||
|
||||
class ActionRepeatWrapper(dm_env.Environment):
|
||||
def __init__(self, env, num_repeats):
|
||||
self._env = env
|
||||
self._num_repeats = num_repeats
|
||||
|
||||
def step(self, action):
|
||||
reward = 0.0
|
||||
discount = 1.0
|
||||
for i in range(self._num_repeats):
|
||||
time_step = self._env.step(action)
|
||||
reward += (time_step.reward or 0.0) * discount
|
||||
discount *= time_step.discount
|
||||
if time_step.last():
|
||||
break
|
||||
|
||||
return time_step._replace(reward=reward, discount=discount)
|
||||
|
||||
def observation_spec(self):
|
||||
return self._env.observation_spec()
|
||||
|
||||
def action_spec(self):
|
||||
return self._env.action_spec()
|
||||
|
||||
def reset(self):
|
||||
return self._env.reset()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
|
||||
class ActionDTypeWrapper(dm_env.Environment):
|
||||
def __init__(self, env, dtype):
|
||||
self._env = env
|
||||
wrapped_action_spec = env.action_spec()
|
||||
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
|
||||
dtype,
|
||||
wrapped_action_spec.minimum,
|
||||
wrapped_action_spec.maximum,
|
||||
'action')
|
||||
|
||||
def step(self, action):
|
||||
action = action.astype(self._env.action_spec().dtype)
|
||||
return self._env.step(action)
|
||||
|
||||
def observation_spec(self):
|
||||
return self._env.observation_spec()
|
||||
|
||||
def action_spec(self):
|
||||
return self._action_spec
|
||||
|
||||
def reset(self):
|
||||
return self._env.reset()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
|
||||
class ExtendedTimeStepWrapper(dm_env.Environment):
|
||||
def __init__(self, env):
|
||||
self._env = env
|
||||
|
||||
def reset(self):
|
||||
time_step = self._env.reset()
|
||||
return self._augment_time_step(time_step)
|
||||
|
||||
def step(self, action):
|
||||
time_step = self._env.step(action)
|
||||
return self._augment_time_step(time_step, action)
|
||||
|
||||
def _augment_time_step(self, time_step, action=None):
|
||||
if action is None:
|
||||
action_spec = self.action_spec()
|
||||
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
|
||||
return ExtendedTimeStep(observation=time_step.observation,
|
||||
step_type=time_step.step_type,
|
||||
action=action,
|
||||
reward=time_step.reward or 0.0,
|
||||
discount=time_step.discount or 1.0)
|
||||
|
||||
def observation_spec(self):
|
||||
return self._env.observation_spec()
|
||||
|
||||
def action_spec(self):
|
||||
return self._env.action_spec()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._env, name)
|
||||
|
||||
|
||||
class TimeStepToGymWrapper:
|
||||
def __init__(self, env, domain, task):
|
||||
obs_shp = []
|
||||
for v in env.observation_spec().values():
|
||||
try:
|
||||
shp = np.prod(v.shape)
|
||||
except:
|
||||
shp = 1
|
||||
obs_shp.append(shp)
|
||||
obs_shp = (int(np.sum(obs_shp)),)
|
||||
act_shp = env.action_spec().shape
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.full(
|
||||
obs_shp,
|
||||
-np.inf,
|
||||
dtype=np.float32),
|
||||
high=np.full(
|
||||
obs_shp,
|
||||
np.inf,
|
||||
dtype=np.float32),
|
||||
dtype=np.float32,
|
||||
)
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=np.full(act_shp, env.action_spec().minimum),
|
||||
high=np.full(act_shp, env.action_spec().maximum),
|
||||
dtype=env.action_spec().dtype)
|
||||
class DMControlWrapper:
|
||||
def __init__(self, env, domain):
|
||||
self.env = env
|
||||
self.domain = domain
|
||||
self.task = task
|
||||
self.max_episode_steps = 500
|
||||
self.t = 0
|
||||
|
||||
self.camera_id = 2 if domain == 'quadruped' else 0
|
||||
obs_shape = get_obs_shape(env)
|
||||
action_shape = env.action_spec().shape
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.full(obs_shape, -np.inf, dtype=np.float32),
|
||||
high=np.full(obs_shape, np.inf, dtype=np.float32),
|
||||
dtype=np.float32)
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=np.full(action_shape, env.action_spec().minimum),
|
||||
high=np.full(action_shape, env.action_spec().maximum),
|
||||
dtype=env.action_spec().dtype)
|
||||
self.action_spec_dtype = env.action_spec().dtype
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
return self.env
|
||||
|
||||
@property
|
||||
def reward_range(self):
|
||||
return None
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return None
|
||||
|
||||
def _obs_to_array(self, obs):
|
||||
return np.concatenate([v.flatten() for v in obs.values()])
|
||||
return torch.from_numpy(
|
||||
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
|
||||
|
||||
def reset(self):
|
||||
return self._obs_to_array(self.env.reset().observation)
|
||||
|
||||
def step(self, action):
|
||||
reward = 0
|
||||
action = action.astype(self.action_spec_dtype)
|
||||
for _ in range(2):
|
||||
step = self.env.step(action)
|
||||
reward += step.reward
|
||||
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
|
||||
|
||||
def render(self, width=384, height=384, camera_id=None):
|
||||
return self.env.physics.render(height, width, camera_id or self.camera_id)
|
||||
|
||||
|
||||
class Pixels(gym.Wrapper):
|
||||
def __init__(self, env, cfg, num_frames=3, size=64):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0, high=255, shape=(num_frames*3, size, size), dtype=np.uint8)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._size = size
|
||||
|
||||
def _get_obs(self, is_reset=False):
|
||||
frame = self.env.render(width=self._size, height=self._size).transpose(2, 0, 1)
|
||||
num_frames = self._frames.maxlen if is_reset else 1
|
||||
for _ in range(num_frames):
|
||||
self._frames.append(frame)
|
||||
return torch.from_numpy(np.concatenate(self._frames))
|
||||
|
||||
def reset(self):
|
||||
self.t = 0
|
||||
return self._obs_to_array(self.env.reset().observation)
|
||||
|
||||
def step(self, action):
|
||||
self.t += 1
|
||||
time_step = self.env.step(action)
|
||||
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float)
|
||||
self.env.reset()
|
||||
return self._get_obs(is_reset=True)
|
||||
|
||||
def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
|
||||
camera_id = dict(quadruped=2).get(self.domain, camera_id)
|
||||
return self.env.physics.render(height, width, camera_id)
|
||||
def step(self, action):
|
||||
_, reward, done, info = self.env.step(action)
|
||||
return self._get_obs(), reward, done, info
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
@@ -192,9 +102,9 @@ def make_env(cfg):
|
||||
task,
|
||||
task_kwargs={'random': cfg.seed},
|
||||
visualize_reward=False)
|
||||
env = ActionDTypeWrapper(env, np.float32)
|
||||
env = ActionRepeatWrapper(env, 2)
|
||||
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
||||
env = ExtendedTimeStepWrapper(env)
|
||||
env = TimeStepToGymWrapper(env, domain, task)
|
||||
env = DMControlWrapper(env, domain)
|
||||
if cfg.obs == 'rgb':
|
||||
env = Pixels(env, cfg)
|
||||
env = Timeout(env, max_episode_steps=500)
|
||||
return env
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
import mani_skill2.envs
|
||||
|
||||
@@ -74,6 +74,6 @@ def make_env(cfg):
|
||||
render_camera_cfgs=dict(width=384, height=384),
|
||||
)
|
||||
env = ManiSkillWrapper(env, cfg)
|
||||
env = TimeLimit(env, max_episode_steps=100)
|
||||
env = Timeout(env, max_episode_steps=100)
|
||||
env.max_episode_steps = env._max_episode_steps
|
||||
return env
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
||||
|
||||
@@ -47,6 +47,6 @@ def make_env(cfg):
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
||||
env = MetaWorldWrapper(env, cfg)
|
||||
env = TimeLimit(env, max_episode_steps=100)
|
||||
env = Timeout(env, max_episode_steps=100)
|
||||
env.max_episode_steps = env._max_episode_steps
|
||||
return env
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
import gymnasium as gym
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
|
||||
MYOSUITE_TASKS = {
|
||||
@@ -53,6 +53,6 @@ def make_env(cfg):
|
||||
from myosuite.utils import gym as gym_utils
|
||||
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
||||
env = MyoSuiteWrapper(env, cfg)
|
||||
env = TimeLimit(env, max_episode_steps=100)
|
||||
env = Timeout(env, max_episode_steps=100)
|
||||
env.max_episode_steps = env._max_episode_steps
|
||||
return env
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from collections import deque
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Compatible with DMControl environments.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_frames=3, render_size=64):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = render_size
|
||||
|
||||
def _get_obs(self):
|
||||
frame = self.env.render(
|
||||
mode='rgb_array', width=self._render_size, height=self._render_size
|
||||
).transpose(2, 0, 1)
|
||||
self._frames.append(frame)
|
||||
return torch.from_numpy(np.concatenate(self._frames))
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs = self._get_obs()
|
||||
return obs
|
||||
|
||||
def step(self, action):
|
||||
_, reward, done, info = self.env.step(action)
|
||||
return self._get_obs(), reward, done, info
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -17,9 +17,10 @@ class TensorWrapper(gym.Wrapper):
|
||||
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
||||
|
||||
def _try_f32_tensor(self, x):
|
||||
x = torch.from_numpy(x)
|
||||
if x.dtype == torch.float64:
|
||||
x = x.float()
|
||||
if isinstance(x, np.ndarray):
|
||||
x = torch.from_numpy(x)
|
||||
if x.dtype == torch.float64:
|
||||
x = x.float()
|
||||
return x
|
||||
|
||||
def _obs_to_tensor(self, obs):
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
Wrapper for limiting the time steps of an environment.
|
||||
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import gym
|
||||
|
||||
|
||||
class TimeLimit(gym.Wrapper):
|
||||
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
|
||||
|
||||
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
|
||||
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
|
||||
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
|
||||
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
|
||||
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
|
||||
|
||||
Example:
|
||||
>>> from gym.envs.classic_control import CartPoleEnv
|
||||
>>> from gym.wrappers import TimeLimit
|
||||
>>> env = CartPoleEnv()
|
||||
>>> env = TimeLimit(env, max_episode_steps=1000)
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
|
||||
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
||||
|
||||
Args:
|
||||
env: The environment to apply the wrapper
|
||||
max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
|
||||
"""
|
||||
super().__init__(env)
|
||||
if max_episode_steps is None and self.env.spec is not None:
|
||||
max_episode_steps = env.spec.max_episode_steps
|
||||
if self.env.spec is not None:
|
||||
self.env.spec.max_episode_steps = max_episode_steps
|
||||
self._max_episode_steps = max_episode_steps
|
||||
self._elapsed_steps = None
|
||||
|
||||
def step(self, action):
|
||||
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
|
||||
|
||||
Args:
|
||||
action: The environment step action
|
||||
|
||||
Returns:
|
||||
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
|
||||
when truncated (the number of steps elapsed >= max episode steps) or
|
||||
"TimeLimit.truncated"=False if the environment terminated
|
||||
"""
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
self._elapsed_steps += 1
|
||||
if self._elapsed_steps >= self._max_episode_steps:
|
||||
# TimeLimit.truncated key may have been already set by the environment
|
||||
# do not overwrite it
|
||||
episode_truncated = not done or info.get("TimeLimit.truncated", False)
|
||||
info["TimeLimit.truncated"] = episode_truncated
|
||||
done = True
|
||||
return observation, reward, done, info
|
||||
|
||||
def reset(self, **kwargs):
|
||||
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
||||
|
||||
Args:
|
||||
**kwargs: The kwargs to reset the environment with
|
||||
|
||||
Returns:
|
||||
The reset environment
|
||||
"""
|
||||
self._elapsed_steps = 0
|
||||
return self.env.reset(**kwargs)
|
||||
25
tdmpc2/envs/wrappers/timeout.py
Normal file
25
tdmpc2/envs/wrappers/timeout.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class Timeout(gym.Wrapper):
    """
    Wrapper for enforcing a time limit on the environment.

    Counts the steps taken since the last ``reset`` and reports
    ``done=True`` once that count reaches ``max_episode_steps``. The wrapped
    environment's ``step`` is expected to return the 4-tuple
    ``(obs, reward, done, info)``, which is what this wrapper unpacks.
    """

    def __init__(self, env, max_episode_steps):
        super().__init__(env)
        self._max_episode_steps = max_episode_steps
        # Create the counter up front so that calling step() before the
        # first reset() does not fail with AttributeError.
        self._t = 0

    @property
    def max_episode_steps(self):
        """Maximum number of steps per episode before ``done`` is forced."""
        return self._max_episode_steps

    @max_episode_steps.setter
    def max_episode_steps(self, value):
        # Without a setter the property is read-only, and callers that do
        # ``env.max_episode_steps = ...`` after wrapping would raise
        # AttributeError.
        self._max_episode_steps = value

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped environment."""
        self._t = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment and force ``done`` at the limit."""
        obs, reward, done, info = self.env.step(action)
        self._t += 1
        done = done or self._t >= self.max_episode_steps
        return obs, reward, done, info
|
||||
Reference in New Issue
Block a user