simplify dmcontrol wrappers + upgrade to gymnasium==0.29.1
This commit is contained in:
@@ -13,10 +13,9 @@ dependencies:
|
|||||||
- pytorch-cuda=12.4
|
- pytorch-cuda=12.4
|
||||||
- torchvision=0.15.2
|
- torchvision=0.15.2
|
||||||
- pip:
|
- pip:
|
||||||
- absl-py==2.1.0
|
- dm-control==1.0.16
|
||||||
- "cython<3"
|
|
||||||
- dm-control==1.0.8
|
|
||||||
- glfw==2.7.0
|
- glfw==2.7.0
|
||||||
|
- gymnasium==0.29.1
|
||||||
- ffmpeg==1.4
|
- ffmpeg==1.4
|
||||||
- imageio==2.34.1
|
- imageio==2.34.1
|
||||||
- imageio-ffmpeg==0.4.9
|
- imageio-ffmpeg==0.4.9
|
||||||
@@ -24,12 +23,9 @@ dependencies:
|
|||||||
- hydra-core==1.3.2
|
- hydra-core==1.3.2
|
||||||
- hydra-submitit-launcher==1.2.0
|
- hydra-submitit-launcher==1.2.0
|
||||||
- submitit==1.5.1
|
- submitit==1.5.1
|
||||||
- setuptools==65.5.0
|
|
||||||
- patchelf==0.17.2.1
|
|
||||||
- omegaconf==2.3.0
|
- omegaconf==2.3.0
|
||||||
- moviepy==1.0.3
|
- moviepy==1.0.3
|
||||||
- mujoco==2.3.1
|
- mujoco==3.1.2
|
||||||
- mujoco-py==2.1.2.14
|
|
||||||
- numpy==1.24.4
|
- numpy==1.24.4
|
||||||
- tensordict-nightly==2024.11.14
|
- tensordict-nightly==2024.11.14
|
||||||
- torchrl-nightly==2024.11.14
|
- torchrl-nightly==2024.11.14
|
||||||
@@ -38,10 +34,14 @@ dependencies:
|
|||||||
- tqdm==4.66.4
|
- tqdm==4.66.4
|
||||||
- pandas==2.0.3
|
- pandas==2.0.3
|
||||||
- wandb==0.17.4
|
- wandb==0.17.4
|
||||||
- wheel==0.38.0
|
|
||||||
####################
|
####################
|
||||||
# Gym:
|
# Gym:
|
||||||
# (unmaintained but required for maniskill2/meta-world)
|
# (unmaintained but required for maniskill2/meta-world)
|
||||||
|
# - "cython<3"
|
||||||
|
# - wheel==0.38.0
|
||||||
|
# - setuptools==65.5.0
|
||||||
|
# - mujoco==2.3.1
|
||||||
|
# - mujoco-py==2.1.2.14
|
||||||
# - gym==0.21.0
|
# - gym==0.21.0
|
||||||
####################
|
####################
|
||||||
# ManiSkill2:
|
# ManiSkill2:
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
from envs.wrappers.pixels import PixelWrapper
|
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
|
|
||||||
def missing_dependencies(task):
|
def missing_dependencies(task):
|
||||||
@@ -70,8 +69,6 @@ def make_env(cfg):
|
|||||||
if env is None:
|
if env is None:
|
||||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||||
env = TensorWrapper(env)
|
env = TensorWrapper(env)
|
||||||
if cfg.get('obs', 'state') == 'rgb':
|
|
||||||
env = PixelWrapper(cfg, env)
|
|
||||||
try: # Dict
|
try: # Dict
|
||||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||||
except: # Box
|
except: # Box
|
||||||
|
|||||||
@@ -1,124 +1,18 @@
|
|||||||
from collections import deque, defaultdict
|
from collections import defaultdict, deque
|
||||||
from typing import Any, NamedTuple
|
|
||||||
import dm_env
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
||||||
from dm_control import suite
|
from dm_control import suite
|
||||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||||
from dm_control.suite.wrappers import action_scale
|
from dm_control.suite.wrappers import action_scale
|
||||||
from dm_env import StepType, specs
|
from envs.wrappers.timeout import Timeout
|
||||||
import gym
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedTimeStep(NamedTuple):
|
def get_obs_shape(env):
|
||||||
step_type: Any
|
|
||||||
reward: Any
|
|
||||||
discount: Any
|
|
||||||
observation: Any
|
|
||||||
action: Any
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return self.step_type == StepType.FIRST
|
|
||||||
|
|
||||||
def mid(self):
|
|
||||||
return self.step_type == StepType.MID
|
|
||||||
|
|
||||||
def last(self):
|
|
||||||
return self.step_type == StepType.LAST
|
|
||||||
|
|
||||||
|
|
||||||
class ActionRepeatWrapper(dm_env.Environment):
|
|
||||||
def __init__(self, env, num_repeats):
|
|
||||||
self._env = env
|
|
||||||
self._num_repeats = num_repeats
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
reward = 0.0
|
|
||||||
discount = 1.0
|
|
||||||
for i in range(self._num_repeats):
|
|
||||||
time_step = self._env.step(action)
|
|
||||||
reward += (time_step.reward or 0.0) * discount
|
|
||||||
discount *= time_step.discount
|
|
||||||
if time_step.last():
|
|
||||||
break
|
|
||||||
|
|
||||||
return time_step._replace(reward=reward, discount=discount)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._env.action_spec()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self._env.reset()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionDTypeWrapper(dm_env.Environment):
|
|
||||||
def __init__(self, env, dtype):
|
|
||||||
self._env = env
|
|
||||||
wrapped_action_spec = env.action_spec()
|
|
||||||
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
|
|
||||||
dtype,
|
|
||||||
wrapped_action_spec.minimum,
|
|
||||||
wrapped_action_spec.maximum,
|
|
||||||
'action')
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
action = action.astype(self._env.action_spec().dtype)
|
|
||||||
return self._env.step(action)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._action_spec
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self._env.reset()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedTimeStepWrapper(dm_env.Environment):
|
|
||||||
def __init__(self, env):
|
|
||||||
self._env = env
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
time_step = self._env.reset()
|
|
||||||
return self._augment_time_step(time_step)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
time_step = self._env.step(action)
|
|
||||||
return self._augment_time_step(time_step, action)
|
|
||||||
|
|
||||||
def _augment_time_step(self, time_step, action=None):
|
|
||||||
if action is None:
|
|
||||||
action_spec = self.action_spec()
|
|
||||||
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
|
|
||||||
return ExtendedTimeStep(observation=time_step.observation,
|
|
||||||
step_type=time_step.step_type,
|
|
||||||
action=action,
|
|
||||||
reward=time_step.reward or 0.0,
|
|
||||||
discount=time_step.discount or 1.0)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._env.action_spec()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeStepToGymWrapper:
|
|
||||||
def __init__(self, env, domain, task):
|
|
||||||
obs_shp = []
|
obs_shp = []
|
||||||
for v in env.observation_spec().values():
|
for v in env.observation_spec().values():
|
||||||
try:
|
try:
|
||||||
@@ -126,56 +20,72 @@ class TimeStepToGymWrapper:
|
|||||||
except:
|
except:
|
||||||
shp = 1
|
shp = 1
|
||||||
obs_shp.append(shp)
|
obs_shp.append(shp)
|
||||||
obs_shp = (int(np.sum(obs_shp)),)
|
return (int(np.sum(obs_shp)),)
|
||||||
act_shp = env.action_spec().shape
|
|
||||||
self.observation_space = gym.spaces.Box(
|
|
||||||
low=np.full(
|
class DMControlWrapper:
|
||||||
obs_shp,
|
def __init__(self, env, domain):
|
||||||
-np.inf,
|
|
||||||
dtype=np.float32),
|
|
||||||
high=np.full(
|
|
||||||
obs_shp,
|
|
||||||
np.inf,
|
|
||||||
dtype=np.float32),
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
self.action_space = gym.spaces.Box(
|
|
||||||
low=np.full(act_shp, env.action_spec().minimum),
|
|
||||||
high=np.full(act_shp, env.action_spec().maximum),
|
|
||||||
dtype=env.action_spec().dtype)
|
|
||||||
self.env = env
|
self.env = env
|
||||||
self.domain = domain
|
self.camera_id = 2 if domain == 'quadruped' else 0
|
||||||
self.task = task
|
obs_shape = get_obs_shape(env)
|
||||||
self.max_episode_steps = 500
|
action_shape = env.action_spec().shape
|
||||||
self.t = 0
|
self.observation_space = gym.spaces.Box(
|
||||||
|
low=np.full(obs_shape, -np.inf, dtype=np.float32),
|
||||||
|
high=np.full(obs_shape, np.inf, dtype=np.float32),
|
||||||
|
dtype=np.float32)
|
||||||
|
self.action_space = gym.spaces.Box(
|
||||||
|
low=np.full(action_shape, env.action_spec().minimum),
|
||||||
|
high=np.full(action_shape, env.action_spec().maximum),
|
||||||
|
dtype=env.action_spec().dtype)
|
||||||
|
self.action_spec_dtype = env.action_spec().dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
return self.env
|
return self.env
|
||||||
|
|
||||||
@property
|
|
||||||
def reward_range(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metadata(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _obs_to_array(self, obs):
|
def _obs_to_array(self, obs):
|
||||||
return np.concatenate([v.flatten() for v in obs.values()])
|
return torch.from_numpy(
|
||||||
|
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.t = 0
|
|
||||||
return self._obs_to_array(self.env.reset().observation)
|
return self._obs_to_array(self.env.reset().observation)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.t += 1
|
reward = 0
|
||||||
time_step = self.env.step(action)
|
action = action.astype(self.action_spec_dtype)
|
||||||
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float)
|
for _ in range(2):
|
||||||
|
step = self.env.step(action)
|
||||||
|
reward += step.reward
|
||||||
|
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
|
||||||
|
|
||||||
def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
|
def render(self, width=384, height=384, camera_id=None):
|
||||||
camera_id = dict(quadruped=2).get(self.domain, camera_id)
|
return self.env.physics.render(height, width, camera_id or self.camera_id)
|
||||||
return self.env.physics.render(height, width, camera_id)
|
|
||||||
|
|
||||||
|
class Pixels(gym.Wrapper):
	"""
	Wrapper that replaces observations with a stack of rendered frames.

	Each observation is the concatenation of the last `num_frames` frames
	rendered from the wrapped environment at `size` x `size` resolution,
	returned as a torch tensor. The advertised observation space is a uint8
	Box of shape (num_frames*3, size, size).
	"""

	def __init__(self, env, cfg, num_frames=3, size=64):
		super().__init__(env)
		self.cfg = cfg
		self.env = env
		self._size = size
		# Bounded buffer: appending beyond `num_frames` drops the oldest frame.
		self._frames = deque([], maxlen=num_frames)
		stacked_shape = (num_frames*3, size, size)
		self.observation_space = gym.spaces.Box(
			low=0, high=255, shape=stacked_shape, dtype=np.uint8)

	def _get_obs(self, is_reset=False):
		"""
		Render one frame, push it onto the frame buffer and return the stack.

		When `is_reset` is True the same frame is pushed enough times to fill
		the whole buffer; otherwise it is pushed once.
		"""
		rendered = self.env.render(width=self._size, height=self._size)
		# Move the last axis first (presumably HWC -> CHW; confirm against the renderer).
		frame = rendered.transpose(2, 0, 1)
		repeats = self._frames.maxlen if is_reset else 1
		self._frames.extend([frame] * repeats)
		return torch.from_numpy(np.concatenate(self._frames))

	def reset(self):
		"""Reset the wrapped environment and return a freshly filled frame stack."""
		self.env.reset()
		return self._get_obs(is_reset=True)

	def step(self, action):
		"""Step the wrapped environment; its observation is discarded in favour of pixels."""
		_, reward, done, info = self.env.step(action)
		pixels = self._get_obs()
		return pixels, reward, done, info
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg):
|
def make_env(cfg):
|
||||||
@@ -192,9 +102,9 @@ def make_env(cfg):
|
|||||||
task,
|
task,
|
||||||
task_kwargs={'random': cfg.seed},
|
task_kwargs={'random': cfg.seed},
|
||||||
visualize_reward=False)
|
visualize_reward=False)
|
||||||
env = ActionDTypeWrapper(env, np.float32)
|
|
||||||
env = ActionRepeatWrapper(env, 2)
|
|
||||||
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
||||||
env = ExtendedTimeStepWrapper(env)
|
env = DMControlWrapper(env, domain)
|
||||||
env = TimeStepToGymWrapper(env, domain, task)
|
if cfg.obs == 'rgb':
|
||||||
|
env = Pixels(env, cfg)
|
||||||
|
env = Timeout(env, max_episode_steps=500)
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
import mani_skill2.envs
|
import mani_skill2.envs
|
||||||
|
|
||||||
@@ -74,6 +74,6 @@ def make_env(cfg):
|
|||||||
render_camera_cfgs=dict(width=384, height=384),
|
render_camera_cfgs=dict(width=384, height=384),
|
||||||
)
|
)
|
||||||
env = ManiSkillWrapper(env, cfg)
|
env = ManiSkillWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
||||||
|
|
||||||
@@ -47,6 +47,6 @@ def make_env(cfg):
|
|||||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||||
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
||||||
env = MetaWorldWrapper(env, cfg)
|
env = MetaWorldWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gymnasium as gym
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
MYOSUITE_TASKS = {
|
MYOSUITE_TASKS = {
|
||||||
@@ -53,6 +53,6 @@ def make_env(cfg):
|
|||||||
from myosuite.utils import gym as gym_utils
|
from myosuite.utils import gym as gym_utils
|
||||||
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
||||||
env = MyoSuiteWrapper(env, cfg)
|
env = MyoSuiteWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
from collections import deque
|
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class PixelWrapper(gym.Wrapper):
	"""
	Wrapper for pixel observations. Compatible with DMControl environments.

	Observations are the concatenation of the most recent `num_frames` frames
	rendered at `render_size` x `render_size`, returned as a torch tensor.
	"""

	def __init__(self, cfg, env, num_frames=3, render_size=64):
		super().__init__(env)
		self.cfg = cfg
		self.env = env
		self._render_size = render_size
		# Bounded buffer: appending beyond `num_frames` drops the oldest frame.
		self._frames = deque([], maxlen=num_frames)
		stacked_shape = (num_frames*3, render_size, render_size)
		self.observation_space = gym.spaces.Box(
			low=0, high=255, shape=stacked_shape, dtype=np.uint8
		)

	def _get_obs(self):
		"""Render one frame, append it to the buffer and return the stacked frames."""
		rendered = self.env.render(
			mode='rgb_array', width=self._render_size, height=self._render_size
		)
		# Move the last axis first (presumably HWC -> CHW; confirm against the renderer).
		self._frames.append(rendered.transpose(2, 0, 1))
		stacked = np.concatenate(self._frames)
		return torch.from_numpy(stacked)

	def reset(self):
		"""Reset the wrapped environment and render once per buffer slot to fill the stack."""
		self.env.reset()
		for _ in range(self._frames.maxlen):
			stacked = self._get_obs()
		return stacked

	def step(self, action):
		"""Step the wrapped environment; its observation is discarded in favour of pixels."""
		_, reward, done, info = self.env.step(action)
		pixels = self._get_obs()
		return pixels, reward, done, info
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
||||||
|
|
||||||
def _try_f32_tensor(self, x):
|
def _try_f32_tensor(self, x):
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
x = torch.from_numpy(x)
|
x = torch.from_numpy(x)
|
||||||
if x.dtype == torch.float64:
|
if x.dtype == torch.float64:
|
||||||
x = x.float()
|
x = x.float()
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
Wrapper for limiting the time steps of an environment.
|
|
||||||
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import gym
|
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit(gym.Wrapper):
	"""Issue a `done` signal once a maximum number of timesteps has elapsed.

	A `done` produced by this wrapper is a truncation, as opposed to a
	termination coming from the underlying environment. The two can be told
	apart through ``info``: when the step limit is reached, the key
	`"TimeLimit.truncated"` is written to ``info``; it is ``True`` when the
	environment itself had not signalled `done` on that step, and otherwise
	keeps whatever value the environment had already stored under that key
	(``False`` if none).
	"""

	def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
		"""Wrap `env` and record the step budget.

		Args:
			env: The environment to apply the wrapper to.
			max_episode_steps: Optional step limit. If ``None`` and the
				environment has a spec, ``env.spec.max_episode_steps`` is used.
				When a spec exists, its ``max_episode_steps`` is overwritten
				with the resolved value.
		"""
		super().__init__(env)
		spec = self.env.spec
		if spec is not None:
			if max_episode_steps is None:
				max_episode_steps = spec.max_episode_steps
			spec.max_episode_steps = max_episode_steps
		self._max_episode_steps = max_episode_steps
		# Stays None until `reset` is called.
		self._elapsed_steps = None

	def step(self, action):
		"""Step the environment and force `done` once the step limit is reached.

		Args:
			action: The environment step action.

		Returns:
			The ``(observation, reward, done, info)`` tuple from the wrapped
			environment. Once the elapsed step count is >= the limit, `done`
			is ``True`` and ``info["TimeLimit.truncated"]`` is set.
		"""
		obs, rew, done, info = self.env.step(action)
		self._elapsed_steps += 1
		if self._elapsed_steps < self._max_episode_steps:
			return obs, rew, done, info
		# The environment may already have set the key; an existing truthy
		# value is preserved rather than overwritten.
		info["TimeLimit.truncated"] = (not done) or info.get("TimeLimit.truncated", False)
		return obs, rew, True, info

	def reset(self, **kwargs):
		"""Zero the elapsed-step counter and reset the wrapped environment.

		Args:
			**kwargs: Forwarded unchanged to the wrapped environment's `reset`.

		Returns:
			Whatever the wrapped environment's `reset` returns.
		"""
		self._elapsed_steps = 0
		return self.env.reset(**kwargs)
|
|
||||||
25
tdmpc2/envs/wrappers/timeout.py
Normal file
25
tdmpc2/envs/wrappers/timeout.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout(gym.Wrapper):
	"""
	Wrapper for enforcing a time limit on the environment.

	Counts steps since the last `reset` and reports `done` once the count
	reaches `max_episode_steps`. Uses the 4-tuple step convention
	`(obs, reward, done, info)` of the environment it wraps.
	"""

	def __init__(self, env, max_episode_steps):
		super().__init__(env)
		self._max_episode_steps = max_episode_steps
		# Initialise the counter here so that `step` before the first `reset`
		# does not fail with an AttributeError on `_t`.
		self._t = 0

	@property
	def max_episode_steps(self):
		"""Number of steps after which an episode is reported as done."""
		return self._max_episode_steps

	@max_episode_steps.setter
	def max_episode_steps(self, value):
		# Environment factories assign `env.max_episode_steps = ...` on the
		# wrapped instance; without a setter that assignment raises
		# AttributeError because the property would be read-only.
		self._max_episode_steps = value

	def reset(self, **kwargs):
		"""Zero the step counter and reset the wrapped environment with `**kwargs`."""
		self._t = 0
		return self.env.reset(**kwargs)

	def step(self, action):
		"""Step the wrapped environment and force `done` once the limit is reached."""
		obs, reward, done, info = self.env.step(action)
		self._t += 1
		done = done or self._t >= self.max_episode_steps
		return obs, reward, done, info
|
||||||
Reference in New Issue
Block a user