diff --git a/docker/environment.yaml b/docker/environment.yaml index c527f0b..45e0389 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -13,10 +13,9 @@ dependencies: - pytorch-cuda=12.4 - torchvision=0.15.2 - pip: - - absl-py==2.1.0 - - "cython<3" - - dm-control==1.0.8 + - dm-control==1.0.16 - glfw==2.7.0 + - gymnasium==0.29.1 - ffmpeg==1.4 - imageio==2.34.1 - imageio-ffmpeg==0.4.9 @@ -24,12 +23,9 @@ dependencies: - hydra-core==1.3.2 - hydra-submitit-launcher==1.2.0 - submitit==1.5.1 - - setuptools==65.5.0 - - patchelf==0.17.2.1 - omegaconf==2.3.0 - moviepy==1.0.3 - - mujoco==2.3.1 - - mujoco-py==2.1.2.14 + - mujoco==3.1.2 - numpy==1.24.4 - tensordict-nightly==2024.11.14 - torchrl-nightly==2024.11.14 @@ -38,10 +34,14 @@ dependencies: - tqdm==4.66.4 - pandas==2.0.3 - wandb==0.17.4 - - wheel==0.38.0 #################### # Gym: # (unmaintained but required for maniskill2/meta-world) + # - "cython<3" + # - wheel==0.38.0 + # - setuptools==65.5.0 + # - mujoco==2.3.1 + # - mujoco-py==2.1.2.14 # - gym==0.21.0 #################### # ManiSkill2: diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 6326a9e..247697f 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -1,10 +1,9 @@ from copy import deepcopy import warnings -import gym +import gymnasium as gym from envs.wrappers.multitask import MultitaskWrapper -from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper def missing_dependencies(task): @@ -70,8 +69,6 @@ def make_env(cfg): if env is None: raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') env = TensorWrapper(env) - if cfg.get('obs', 'state') == 'rgb': - env = PixelWrapper(cfg, env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} except: # Box diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..5df96a1 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -1,181 +1,91 @@ -from collections import deque, defaultdict -from typing import Any, NamedTuple -import dm_env +from collections import defaultdict, deque + +import gymnasium as gym import numpy as np +import torch + from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish from dm_control import suite suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) from dm_control.suite.wrappers import action_scale -from dm_env import StepType, specs -import gym +from envs.wrappers.timeout import Timeout -class ExtendedTimeStep(NamedTuple): - step_type: Any - reward: Any - discount: Any - observation: Any - action: Any - - def first(self): - return self.step_type == StepType.FIRST - - def mid(self): - return self.step_type == StepType.MID - - def last(self): - return self.step_type == StepType.LAST +def get_obs_shape(env): + obs_shp = [] + for v in env.observation_spec().values(): + try: + shp = np.prod(v.shape) + except: + shp = 1 + obs_shp.append(shp) + return (int(np.sum(obs_shp)),) -class ActionRepeatWrapper(dm_env.Environment): - def __init__(self, env, num_repeats): - self._env = env - self._num_repeats = num_repeats - - def step(self, action): - reward = 0.0 - discount = 1.0 - for i in range(self._num_repeats): - time_step = self._env.step(action) - reward += (time_step.reward or 0.0) * discount - discount *= time_step.discount - if time_step.last(): - break - - return time_step._replace(reward=reward, discount=discount) - - def observation_spec(self): - return self._env.observation_spec() - - def action_spec(self): - return self._env.action_spec() - - def reset(self): - return self._env.reset() - - def __getattr__(self, name): - return getattr(self._env, name) - - -class ActionDTypeWrapper(dm_env.Environment): - def __init__(self, env, dtype): - self._env = env - wrapped_action_spec = env.action_spec() - self._action_spec = specs.BoundedArray(wrapped_action_spec.shape, - dtype, - wrapped_action_spec.minimum, - wrapped_action_spec.maximum, - 'action') - - def step(self, action): - action = action.astype(self._env.action_spec().dtype) - return self._env.step(action) - - def observation_spec(self): - return self._env.observation_spec() - - def action_spec(self): - return self._action_spec - - def reset(self): - return self._env.reset() - - def __getattr__(self, name): - return getattr(self._env, name) - - -class ExtendedTimeStepWrapper(dm_env.Environment): - def __init__(self, env): - self._env = env - - def reset(self): - time_step = self._env.reset() - return self._augment_time_step(time_step) - - def step(self, action): - time_step = self._env.step(action) - return self._augment_time_step(time_step, action) - - def _augment_time_step(self, time_step, action=None): - if action is None: - action_spec = self.action_spec() - action = np.zeros(action_spec.shape, dtype=action_spec.dtype) - return ExtendedTimeStep(observation=time_step.observation, - step_type=time_step.step_type, - action=action, - reward=time_step.reward or 0.0, - discount=time_step.discount or 1.0) - - def observation_spec(self): - return self._env.observation_spec() - - def action_spec(self): - return self._env.action_spec() - - def __getattr__(self, name): - return getattr(self._env, name) - - -class TimeStepToGymWrapper: - def __init__(self, env, domain, task): - obs_shp = [] - for v in env.observation_spec().values(): - try: - shp = np.prod(v.shape) - except: - shp = 1 - obs_shp.append(shp) - obs_shp = (int(np.sum(obs_shp)),) - act_shp = env.action_spec().shape - self.observation_space = gym.spaces.Box( - low=np.full( - obs_shp, - -np.inf, - dtype=np.float32), - high=np.full( - obs_shp, - np.inf, - dtype=np.float32), - dtype=np.float32, - ) - self.action_space = gym.spaces.Box( - low=np.full(act_shp, env.action_spec().minimum), - high=np.full(act_shp, env.action_spec().maximum), - dtype=env.action_spec().dtype) +class DMControlWrapper: + def __init__(self, env, domain): self.env = env - self.domain = domain - self.task = task - self.max_episode_steps = 500 - self.t = 0 - + self.camera_id = 2 if domain == 'quadruped' else 0 + obs_shape = get_obs_shape(env) + action_shape = env.action_spec().shape + self.observation_space = gym.spaces.Box( + low=np.full(obs_shape, -np.inf, dtype=np.float32), + high=np.full(obs_shape, np.inf, dtype=np.float32), + dtype=np.float32) + self.action_space = gym.spaces.Box( + low=np.full(action_shape, env.action_spec().minimum), + high=np.full(action_shape, env.action_spec().maximum), + dtype=env.action_spec().dtype) + self.action_spec_dtype = env.action_spec().dtype + @property def unwrapped(self): return self.env - - @property - def reward_range(self): - return None - - @property - def metadata(self): - return None def _obs_to_array(self, obs): - return np.concatenate([v.flatten() for v in obs.values()]) + return torch.from_numpy( + np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32)) + + def reset(self): + return self._obs_to_array(self.env.reset().observation) + + def step(self, action): + reward = 0 + action = action.astype(self.action_spec_dtype) + for _ in range(2): + step = self.env.step(action) + reward += step.reward + return self._obs_to_array(step.observation), reward, False, defaultdict(float) + + def render(self, width=384, height=384, camera_id=None): + return self.env.physics.render(height, width, camera_id or self.camera_id) + + +class Pixels(gym.Wrapper): + def __init__(self, env, cfg, num_frames=3, size=64): + super().__init__(env) + self.cfg = cfg + self.env = env + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(num_frames*3, size, size), dtype=np.uint8) + self._frames = deque([], maxlen=num_frames) + self._size = size + + def _get_obs(self, is_reset=False): + frame = self.env.render(width=self._size, height=self._size).transpose(2, 0, 1) + num_frames = self._frames.maxlen if is_reset else 1 + for _ in range(num_frames): + self._frames.append(frame) + return torch.from_numpy(np.concatenate(self._frames)) def reset(self): - self.t = 0 - return self._obs_to_array(self.env.reset().observation) - - def step(self, action): - self.t += 1 - time_step = self.env.step(action) - return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float) + self.env.reset() + return self._get_obs(is_reset=True) - def render(self, mode='rgb_array', width=384, height=384, camera_id=0): - camera_id = dict(quadruped=2).get(self.domain, camera_id) - return self.env.physics.render(height, width, camera_id) + def step(self, action): + _, reward, done, info = self.env.step(action) + return self._get_obs(), reward, done, info def make_env(cfg): @@ -192,9 +102,9 @@ def make_env(cfg): task, task_kwargs={'random': cfg.seed}, visualize_reward=False) - env = ActionDTypeWrapper(env, np.float32) - env = ActionRepeatWrapper(env, 2) env = action_scale.Wrapper(env, minimum=-1., maximum=1.) - env = ExtendedTimeStepWrapper(env) - env = TimeStepToGymWrapper(env, domain, task) + env = DMControlWrapper(env, domain) + if cfg.obs == 'rgb': + env = Pixels(env, cfg) + env = Timeout(env, max_episode_steps=500) return env diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 7b0b6ed..cac45aa 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -1,6 +1,6 @@ -import gym +import gymnasium as gym import numpy as np -from envs.wrappers.time_limit import TimeLimit +from envs.wrappers.timeout import Timeout import mani_skill2.envs @@ -74,6 +74,6 @@ def make_env(cfg): render_camera_cfgs=dict(width=384, height=384), ) env = ManiSkillWrapper(env, cfg) - env = TimeLimit(env, max_episode_steps=100) + env = Timeout(env, max_episode_steps=100) env.max_episode_steps = env._max_episode_steps return env diff --git a/tdmpc2/envs/metaworld.py b/tdmpc2/envs/metaworld.py index f5f4f0d..f9b3513 100644 --- a/tdmpc2/envs/metaworld.py +++ b/tdmpc2/envs/metaworld.py @@ -1,6 +1,6 @@ import numpy as np import gym -from envs.wrappers.time_limit import TimeLimit +from envs.wrappers.timeout import Timeout from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE @@ -47,6 +47,6 @@ def make_env(cfg): assert cfg.obs == 'state', 'This task only supports state observations.' env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = MetaWorldWrapper(env, cfg) - env = TimeLimit(env, max_episode_steps=100) + env = Timeout(env, max_episode_steps=100) env.max_episode_steps = env._max_episode_steps return env diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py index d15f11f..ee35fd1 100644 --- a/tdmpc2/envs/myosuite.py +++ b/tdmpc2/envs/myosuite.py @@ -1,6 +1,6 @@ import numpy as np -import gym -from envs.wrappers.time_limit import TimeLimit +import gymnasium as gym +from envs.wrappers.timeout import Timeout MYOSUITE_TASKS = { @@ -53,6 +53,6 @@ def make_env(cfg): from myosuite.utils import gym as gym_utils env = gym_utils.make(MYOSUITE_TASKS[cfg.task]) env = MyoSuiteWrapper(env, cfg) - env = TimeLimit(env, max_episode_steps=100) + env = Timeout(env, max_episode_steps=100) env.max_episode_steps = env._max_episode_steps return env diff --git a/tdmpc2/envs/wrappers/multitask.py b/tdmpc2/envs/wrappers/multitask.py index 08dd4eb..529295a 100644 --- a/tdmpc2/envs/wrappers/multitask.py +++ b/tdmpc2/envs/wrappers/multitask.py @@ -1,4 +1,4 @@ -import gym +import gymnasium as gym import numpy as np import torch diff --git a/tdmpc2/envs/wrappers/pixels.py b/tdmpc2/envs/wrappers/pixels.py deleted file mode 100644 index c299875..0000000 --- a/tdmpc2/envs/wrappers/pixels.py +++ /dev/null @@ -1,38 +0,0 @@ -from collections import deque - -import gym -import numpy as np -import torch - - -class PixelWrapper(gym.Wrapper): - """ - Wrapper for pixel observations. Compatible with DMControl environments. - """ - - def __init__(self, cfg, env, num_frames=3, render_size=64): - super().__init__(env) - self.cfg = cfg - self.env = env - self.observation_space = gym.spaces.Box( - low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8 - ) - self._frames = deque([], maxlen=num_frames) - self._render_size = render_size - - def _get_obs(self): - frame = self.env.render( - mode='rgb_array', width=self._render_size, height=self._render_size - ).transpose(2, 0, 1) - self._frames.append(frame) - return torch.from_numpy(np.concatenate(self._frames)) - - def reset(self): - self.env.reset() - for _ in range(self._frames.maxlen): - obs = self._get_obs() - return obs - - def step(self, action): - _, reward, done, info = self.env.step(action) - return self._get_obs(), reward, done, info diff --git a/tdmpc2/envs/wrappers/tensor.py b/tdmpc2/envs/wrappers/tensor.py index 548a5f4..5054989 100644 --- a/tdmpc2/envs/wrappers/tensor.py +++ b/tdmpc2/envs/wrappers/tensor.py @@ -1,6 +1,6 @@ from collections import defaultdict -import gym +import gymnasium as gym import numpy as np import torch @@ -17,9 +17,10 @@ class TensorWrapper(gym.Wrapper): return torch.from_numpy(self.action_space.sample().astype(np.float32)) def _try_f32_tensor(self, x): - x = torch.from_numpy(x) - if x.dtype == torch.float64: - x = x.float() + if isinstance(x, np.ndarray): + x = torch.from_numpy(x) + if x.dtype == torch.float64: + x = x.float() return x def _obs_to_tensor(self, obs): diff --git a/tdmpc2/envs/wrappers/time_limit.py b/tdmpc2/envs/wrappers/time_limit.py deleted file mode 100644 index f81c281..0000000 --- a/tdmpc2/envs/wrappers/time_limit.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Wrapper for limiting the time steps of an environment. -Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py -""" -from typing import Optional - -import gym - - -class TimeLimit(gym.Wrapper): - """This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded. - - Oftentimes, it is **very** important to distinguish `done` signals that were produced by the - :class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations). - This can be done by looking at the ``info`` that is returned when `done`-signal was issued. - The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if - the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``. - - Example: - >>> from gym.envs.classic_control import CartPoleEnv - >>> from gym.wrappers import TimeLimit - >>> env = CartPoleEnv() - >>> env = TimeLimit(env, max_episode_steps=1000) - """ - - def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None): - """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. - - Args: - env: The environment to apply the wrapper - max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) - """ - super().__init__(env) - if max_episode_steps is None and self.env.spec is not None: - max_episode_steps = env.spec.max_episode_steps - if self.env.spec is not None: - self.env.spec.max_episode_steps = max_episode_steps - self._max_episode_steps = max_episode_steps - self._elapsed_steps = None - - def step(self, action): - """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate. - - Args: - action: The environment step action - - Returns: - The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True - when truncated (the number of steps elapsed >= max episode steps) or - "TimeLimit.truncated"=False if the environment terminated - """ - observation, reward, done, info = self.env.step(action) - self._elapsed_steps += 1 - if self._elapsed_steps >= self._max_episode_steps: - # TimeLimit.truncated key may have been already set by the environment - # do not overwrite it - episode_truncated = not done or info.get("TimeLimit.truncated", False) - info["TimeLimit.truncated"] = episode_truncated - done = True - return observation, reward, done, info - - def reset(self, **kwargs): - """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero. - - Args: - **kwargs: The kwargs to reset the environment with - - Returns: - The reset environment - """ - self._elapsed_steps = 0 - return self.env.reset(**kwargs) diff --git a/tdmpc2/envs/wrappers/timeout.py b/tdmpc2/envs/wrappers/timeout.py new file mode 100644 index 0000000..cc2081c --- /dev/null +++ b/tdmpc2/envs/wrappers/timeout.py @@ -0,0 +1,25 @@ +import gymnasium as gym + + +class Timeout(gym.Wrapper): + """ + Wrapper for enforcing a time limit on the environment. + """ + + def __init__(self, env, max_episode_steps): + super().__init__(env) + self._max_episode_steps = max_episode_steps + + @property + def max_episode_steps(self): + return self._max_episode_steps + + def reset(self, **kwargs): + self._t = 0 + return self.env.reset(**kwargs) + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self._t += 1 + done = done or self._t >= self.max_episode_steps + return obs, reward, done, info