diff --git a/README.md b/README.md
index a3ee4f1..94534a4 100755
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.

-This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
+This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
----
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t <user>/tdmpc2:0.1.0
```
-If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
+If you prefer to install dependencies manually, start by creating a `conda` environment with one of the following commands:
```
conda env create -f docker/environment.yaml
+conda env create -f docker/environment_minimal.yaml
```
+The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
+
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall
| maniskill | pick-cube
| maniskill | pick-ycb
-| myosuite | myo-hand-key-turn
-| myosuite | myo-hand-key-turn-hard
+| myosuite | myo-key-turn
+| myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
+**As of Dec 27, 2023, the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; pass the argument `obs=rgb` if you wish to train visual policies.
+
## Example usage
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
+$ python train.py task=walker-walk obs=rgb
```
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
diff --git a/docker/environment.yaml b/docker/environment.yaml
index 18a9914..6792839 100644
--- a/docker/environment.yaml
+++ b/docker/environment.yaml
@@ -26,6 +26,7 @@ dependencies:
- hydra-core
- hydra-submitit-launcher
- submitit
+ - pandas
- patchelf
- protobuf
- tqdm
diff --git a/docker/environment_minimal.yaml b/docker/environment_minimal.yaml
new file mode 100644
index 0000000..fbe30f6
--- /dev/null
+++ b/docker/environment_minimal.yaml
@@ -0,0 +1,39 @@
+name: tdmpc2
+channels:
+ - pytorch-nightly
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.9.0
+ - pytorch
+ - torchvision
+ - cudatoolkit=11.7
+ - glew
+ - glib
+ - pip==21
+ - pip:
+ - absl-py
+ - glfw
+ - kornia
+ - termcolor
+ - gym==0.21.0
+ - moviepy
+ - ffmpeg
+ - imageio
+ - imageio-ffmpeg
+ - omegaconf
+ - hydra-core
+ - hydra-submitit-launcher
+ - submitit
+ - pandas
+ - patchelf
+ - protobuf
+ - tqdm
+ - setuptools==65.5.0
+ - "cython<3"
+ - dm-control
+ - pillow
+ - tensordict-nightly
+ - torchrl-nightly
+ - wandb
diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py
index dbbfea6..29cc293 100644
--- a/tdmpc2/common/buffer.py
+++ b/tdmpc2/common/buffer.py
@@ -1,43 +1,27 @@
-from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
-from torchrl.data.replay_buffers.samplers import RandomSampler
-from torchrl.envs import RandomCropTensorDict, Transform, Compose
-from common.logger import make_dir
-
-
-class DataPrepTransform(Transform):
- """
- Preprocesses data for TD-MPC2 training.
- Replay data is expected to be a TensorDict with the following keys:
- obs: observations
- action: actions
- reward: rewards
- task: task IDs (optional)
- A TensorDict with T time steps has T+1 observations and T actions and rewards.
- The first actions and rewards in each TensorDict are dummies and should be ignored.
- """
-
- def __init__(self):
- super().__init__([])
-
- def forward(self, td):
- td = td.permute(1,0)
- return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
+from common.samplers import SliceSampler
class Buffer():
"""
- Create a replay buffer for TD-MPC2 training.
+ Replay buffer for TD-MPC2 training. Based on torchrl.
Uses CUDA memory if available, and CPU memory otherwise.
"""
def __init__(self, cfg):
self.cfg = cfg
self._device = torch.device('cuda')
- self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
+ self._capacity = min(cfg.buffer_size, cfg.steps)
+ self._sampler = SliceSampler(
+ num_slices=self.cfg.batch_size,
+ end_key=None,
+ traj_key='episode',
+ truncated_key=None,
+ )
+ self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0
@property
@@ -53,63 +37,60 @@ class Buffer():
def _reserve_buffer(self, storage):
"""
Reserve a buffer with the given storage.
- Uses the RandomSampler to sample trajectories,
- and the RandomCropTensorDict transform to crop trajectories to the desired length.
- DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
"""
return ReplayBuffer(
storage=storage,
- sampler=RandomSampler(),
+ sampler=self._sampler,
pin_memory=True,
prefetch=1,
- transform=Compose(
- RandomCropTensorDict(self.cfg.horizon+1, -1),
- DataPrepTransform(),
- ),
- batch_size=self.cfg.batch_size,
+ batch_size=self._batch_size,
)
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
+ print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info()
- bytes_per_ep = sum([
+ bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
- for k,v in tds.items()
- ])
- print(f'Bytes per episode: {bytes_per_ep:,}')
- total_bytes = bytes_per_ep*self._capacity
+ for v in tds.values()
+ ]) / len(tds)
+ total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
- if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
- print('Using CPU memory for storage.')
- return self._reserve_buffer(
- LazyTensorStorage(self._capacity, device=torch.device('cpu'))
- )
- else: # Sufficient CUDA memory
- print('Using CUDA memory for storage.')
- return self._reserve_buffer(
- LazyTensorStorage(self._capacity, device=torch.device('cuda'))
- )
+ storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
+ print(f'Using {storage_device.upper()} memory for storage.')
+ return self._reserve_buffer(
+ LazyTensorStorage(self._capacity, device=torch.device(storage_device))
+ )
- def add(self, tds):
- """Add an episode to the buffer. All episodes are expected to have the same length."""
+ def _to_device(self, *args, device=None):
+ if device is None:
+ device = self._device
+ return (arg.to(device, non_blocking=True) \
+ if arg is not None else None for arg in args)
+
+ def _prepare_batch(self, td):
+ """
+ Prepare a sampled batch for training (post-processing).
+ Expects `td` to be a TensorDict with batch size TxB.
+ """
+ obs = td['obs']
+ action = td['action'][1:]
+ reward = td['reward'][1:].unsqueeze(-1)
+ task = td['task'][0] if 'task' in td.keys() else None
+ return self._to_device(obs, action, reward, task)
+
+ def add(self, td):
+ """Add an episode to the buffer."""
+ td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0:
- self._buffer = self._init(tds)
- self._buffer.add(tds)
+ self._buffer = self._init(td)
+ self._buffer.extend(td)
self._num_eps += 1
return self._num_eps
def sample(self):
- """Sample a batch of sub-trajectories from the buffer."""
- obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
- return obs.to(self._device, non_blocking=True), \
- action.to(self._device, non_blocking=True), \
- reward.to(self._device, non_blocking=True), \
- task.to(self._device, non_blocking=True) if task is not None else None
-
- def save(self):
- """Save the buffer to disk. Useful for storing offline datasets."""
- td = self._buffer._storage._storage.cpu()
- fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
- torch.save(td, fp)
+ """Sample a batch of subsequences from the buffer."""
+ td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
+ return self._prepare_batch(td)
diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py
new file mode 100644
index 0000000..71cf70c
--- /dev/null
+++ b/tdmpc2/common/samplers.py
@@ -0,0 +1,351 @@
+from __future__ import annotations
+
+from copy import copy
+from typing import Tuple
+
+import torch
+
+from tensordict.utils import NestedKey
+
+from torchrl.data.replay_buffers.storages import Storage, TensorStorage
+from torchrl.data.replay_buffers.samplers import Sampler
+
+
+# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
+# This copy will live here until it has been included in a few torchrl stable releases
+
+
+class SliceSampler(Sampler):
+ """Samples slices of data along the first dimension, given start and stop signals.
+
+ This class samples sub-trajectories with replacement. For a version without
+ replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
+
+ Keyword Args:
+ num_slices (int): the number of slices to be sampled. The batch-size
+ must be greater or equal to the ``num_slices`` argument. Exclusive
+ with ``slice_len``.
+ slice_len (int): the length of the slices to be sampled. The batch-size
+ must be greater or equal to the ``slice_len`` argument and divisible
+ by it. Exclusive with ``num_slices``.
+ end_key (NestedKey, optional): the key indicating the end of a
+ trajectory (or episode). Defaults to ``("next", "done")``.
+ traj_key (NestedKey, optional): the key indicating the trajectories.
+ Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
+ cache_values (bool, optional): to be used with static datasets.
+ Will cache the start and end signal of the trajectory.
+ truncated_key (NestedKey, optional): If not ``None``, this argument
+ indicates where a truncated signal should be written in the output
+ data. This is used to indicate to value estimators where the provided
+ trajectory breaks. Defaults to ``("next", "truncated")``.
+ This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
+ instances (otherwise the truncated key is returned in the info dictionary
+ returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
+ strict_length (bool, optional): if ``False``, trajectories of length
+ shorter than `slice_len` (or `batch_size // num_slices`) will be
+ allowed to appear in the batch.
+ Be mindful that this can result in effective `batch_size` shorter
+ than the one asked for! Trajectories can be split using
+ :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
+
+ .. note:: To recover the trajectory splits in the storage,
+ :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
+ attempt to find the ``traj_key`` entry in the storage. If it cannot be
+ found, the ``end_key`` will be used to reconstruct the episodes.
+
+ Examples:
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
+ >>> from torchrl.data.replay_buffers.samplers import SliceSampler
+ >>> torch.manual_seed(0)
+ >>> rb = TensorDictReplayBuffer(
+ ... storage=LazyMemmapStorage(1_000_000),
+ ... sampler=SliceSampler(cache_values=True, num_slices=10),
+ ... batch_size=320,
+ ... )
+ >>> episode = torch.zeros(1000, dtype=torch.int)
+ >>> episode[:300] = 1
+ >>> episode[300:550] = 2
+ >>> episode[550:700] = 3
+ >>> episode[700:] = 4
+ >>> data = TensorDict(
+ ... {
+ ... "episode": episode,
+ ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
+ ... "act": torch.randn((20,)).expand(1000, 20),
+ ... "other": torch.randn((20, 50)).expand(1000, 20, 50),
+ ... }, [1000]
+ ... )
+ >>> rb.extend(data)
+ >>> sample = rb.sample()
+ >>> print("sample:", sample)
+ >>> print("episodes", sample.get("episode").unique())
+ episodes tensor([1, 2, 3, 4], dtype=torch.int32)
+
+ :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
+ most of TorchRL's datasets:
+
+ Examples:
+ >>> import torch
+ >>>
+ >>> from torchrl.data.datasets import RobosetExperienceReplay
+ >>> from torchrl.data import SliceSampler
+ >>>
+ >>> torch.manual_seed(0)
+ >>> num_slices = 10
+ >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
+ >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
+ >>> for batch in data:
+ ... batch = batch.reshape(num_slices, -1)
+ ... break
+ >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
+ check that each batch only has one episode: tensor([[19],
+ [14],
+ [ 8],
+ [10],
+ [13],
+ [ 4],
+ [ 2],
+ [ 3],
+ [22],
+ [ 8]])
+
+ """
+
+    def __init__(
+        self,
+        *,
+        num_slices: int | None = None,
+        slice_len: int | None = None,
+        end_key: NestedKey | None = None,
+        traj_key: NestedKey | None = None,
+        cache_values: bool = False,
+        truncated_key: NestedKey | None = ("next", "truncated"),
+        strict_length: bool = True,
+    ) -> None:
+        """Store the slicing configuration (arguments are described on the class)."""
+        # Exactly one of the two sizing arguments may be given.
+        if not ((num_slices is None) ^ (slice_len is None)):
+            raise TypeError(
+                "Either num_slices or slice_len must be not None, and not both. "
+                f"Got num_slices={num_slices} and slice_len={slice_len}."
+            )
+        self.num_slices = num_slices
+        self.slice_len = slice_len
+        self.end_key = ("next", "done") if end_key is None else end_key
+        self.traj_key = "episode" if traj_key is None else traj_key
+        self.truncated_key = truncated_key
+        self.cache_values = cache_values
+        # Read episode boundaries from traj_key first; end_key is the fallback.
+        self._fetch_traj = True
+        self._uses_data_prefix = False
+        self.strict_length = strict_length
+        # Holds the "stop-and-length" entry when cache_values is enabled.
+        self._cache = {}
+
+ @staticmethod
+ def _find_start_stop_traj(*, trajectory=None, end=None):
+ if trajectory is not None:
+ # slower
+ # _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
+ # stop_idx = stop_idx.cumsum(0) - 1
+
+ # even slower
+ # t = trajectory.unsqueeze(0)
+ # w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
+ # stop_idx = torch.conv1d(t, w).nonzero()
+
+ # faster
+ end = trajectory[:-1] != trajectory[1:]
+ end = torch.cat([end, torch.ones_like(end[:1])], 0)
+ else:
+ end = torch.index_fill(
+ end,
+ index=torch.tensor(-1, device=end.device, dtype=torch.long),
+ dim=0,
+ value=1,
+ )
+ if end.ndim != 1:
+ raise RuntimeError(
+ f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
+ )
+ stop_idx = end.view(-1).nonzero().view(-1)
+ start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
+ lengths = stop_idx - start_idx + 1
+ return start_idx, stop_idx, lengths
+
+ def _tensor_slices_from_startend(self, seq_length, start):
+ if isinstance(seq_length, int):
+ return (
+ torch.arange(
+ seq_length, device=start.device, dtype=start.dtype
+ ).unsqueeze(0)
+ + start.unsqueeze(1)
+ ).view(-1)
+ else:
+ # when padding is needed
+ return torch.cat(
+ [
+ _start
+ + torch.arange(_seq_len, device=start.device, dtype=start.dtype)
+ for _start, _seq_len in zip(start, seq_length)
+ ]
+ )
+
+    def _get_stop_and_length(self, storage, fallback=True):
+        """Return the ``(start_idx, stop_idx, lengths)`` triple of the stored trajectories.
+
+        The boundaries are read from ``traj_key`` or from ``end_key`` depending on
+        ``self._fetch_traj``. On ``KeyError`` the other key is tried once when
+        ``fallback`` is true; otherwise the error propagates.
+        """
+        if self.cache_values and "stop-and-length" in self._cache:
+            return self._cache.get("stop-and-length")
+
+        if self._fetch_traj:
+            try:
+                # Some storages hide the data behind "_data": retry with that
+                # prefix and remember the key that worked for later calls.
+                try:
+                    trajectory = storage._storage.get(self._used_traj_key)
+                except KeyError:
+                    trajectory = storage._storage.get(("_data", self.traj_key))
+                    self._used_traj_key = ("_data", self.traj_key)
+                self._uses_data_prefix = (
+                    isinstance(self._used_traj_key, tuple)
+                    and self._used_traj_key[0] == "_data"
+                )
+                vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
+                if self.cache_values:
+                    self._cache["stop-and-length"] = vals
+                return vals
+            except KeyError:
+                if fallback:
+                    self._fetch_traj = False
+                    return self._get_stop_and_length(storage, fallback=False)
+                raise
+
+        else:
+            try:
+                # Same "_data" fallback as above, applied to the end-of-trajectory key.
+                try:
+                    done = storage._storage.get(self._used_end_key)
+                except KeyError:
+                    done = storage._storage.get(("_data", self.end_key))
+                    self._used_end_key = ("_data", self.end_key)
+                self._uses_data_prefix = (
+                    isinstance(self._used_end_key, tuple)
+                    and self._used_end_key[0] == "_data"
+                )
+                # Truncate the signal itself to len(storage); slicing the returned
+                # 3-tuple (as before) did not restrict the data that was scanned.
+                vals = self._find_start_stop_traj(end=done.squeeze()[: len(storage)])
+                if self.cache_values:
+                    self._cache["stop-and-length"] = vals
+                return vals
+            except KeyError:
+                if fallback:
+                    self._fetch_traj = True
+                    return self._get_stop_and_length(storage, fallback=False)
+                raise
+
+ def _adjusted_batch_size(self, batch_size):
+ if self.num_slices is not None:
+ if batch_size % self.num_slices != 0:
+ raise RuntimeError(
+ f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
+ )
+ seq_length = batch_size // self.num_slices
+ num_slices = self.num_slices
+ else:
+ if batch_size % self.slice_len != 0:
+ raise RuntimeError(
+ f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
+ )
+ seq_length = self.slice_len
+ num_slices = batch_size // self.slice_len
+ return seq_length, num_slices
+
+ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+ if not isinstance(storage, TensorStorage):
+ raise RuntimeError(
+ f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
+ )
+
+ # pick up as many trajs as we need
+ start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
+ seq_length, num_slices = self._adjusted_batch_size(batch_size)
+ return self._sample_slices(lengths, start_idx, seq_length, num_slices)
+
+    def _sample_slices(
+        self, lengths, start_idx, seq_length, num_slices, traj_idx=None
+    ) -> Tuple[torch.Tensor, dict]:
+        """Draw random windows and return their flat storage indices plus an info dict.
+
+        When ``truncated_key`` is set, the info dict flags the last step of each slice.
+        """
+        has_short = bool((lengths < seq_length).any())
+        if has_short and self.strict_length:
+            raise RuntimeError(
+                "Some stored trajectories have a length shorter than the slice that was asked for. "
+                "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
+                "in your batch."
+            )
+        if traj_idx is None:
+            traj_idx = torch.randint(
+                lengths.shape[0], (num_slices,), device=lengths.device
+            )
+        else:
+            num_slices = traj_idx.shape[0]
+        if has_short:
+            # Clamp per sampled slice (not per stored trajectory) so the tensor
+            # lines up with traj_idx here and in _tensor_slices_from_startend.
+            seq_length = lengths[traj_idx].clamp_max(seq_length)
+        relative_starts = (
+            torch.rand(num_slices, device=lengths.device)
+            * (lengths[traj_idx] - seq_length)
+        ).floor().to(start_idx.dtype)
+        starts = start_idx[traj_idx] + relative_starts
+        index = self._tensor_slices_from_startend(seq_length, starts).to(torch.long)
+        if self.truncated_key is None:
+            return index, {}
+        # Flag the last element of every slice as truncated.
+        truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
+        if isinstance(seq_length, int):
+            truncated.view(num_slices, -1)[:, -1] = 1
+        else:
+            truncated[seq_length.cumsum(0) - 1] = 1
+        return index, {self.truncated_key: truncated}
+
+ @property
+ def _used_traj_key(self):
+ return self.__dict__.get("__used_traj_key", self.traj_key)
+
+ @_used_traj_key.setter
+ def _used_traj_key(self, value):
+ self.__dict__["__used_traj_key"] = value
+
+ @property
+ def _used_end_key(self):
+ return self.__dict__.get("__used_end_key", self.end_key)
+
+ @_used_end_key.setter
+ def _used_end_key(self, value):
+ self.__dict__["__used_end_key"] = value
+
+ def _empty(self):
+ pass
+
+ def dumps(self, path):
+ # no op - cache does not need to be saved
+ ...
+
+ def loads(self, path):
+ # no op
+ ...
+
+ def __getstate__(self):
+ state = copy(self.__dict__)
+ state["_cache"] = {}
+ return state
diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py
index dfac4b5..6326a9e 100644
--- a/tdmpc2/envs/__init__.py
+++ b/tdmpc2/envs/__init__.py
@@ -6,11 +6,27 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
-from envs.dmcontrol import make_env as make_dm_control_env
-from envs.maniskill import make_env as make_maniskill_env
-from envs.metaworld import make_env as make_metaworld_env
-from envs.myosuite import make_env as make_myosuite_env
-from envs.exceptions import UnknownTaskError
+
+def missing_dependencies(task):
+    """Stand-in for an unavailable backend's ``make_env``; accepts a cfg or a task name and always raises ``ValueError``."""
+    raise ValueError(f'Missing dependencies for task {getattr(task, "task", task)}; install dependencies to use this environment.')
+try:
+    from envs.dmcontrol import make_env as make_dm_control_env
+except Exception:  # backends are optional: any import-time failure disables one, but Ctrl-C/SystemExit still propagate
+    make_dm_control_env = missing_dependencies
+try:
+    from envs.maniskill import make_env as make_maniskill_env
+except Exception:
+    make_maniskill_env = missing_dependencies
+try:
+    from envs.metaworld import make_env as make_metaworld_env
+except Exception:
+    make_metaworld_env = missing_dependencies
+try:
+    from envs.myosuite import make_env as make_myosuite_env
+except Exception:
+    make_myosuite_env = missing_dependencies
+
warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
_cfg.multitask = False
env = make_env(_cfg)
if env is None:
- raise UnknownTaskError(task)
+ raise ValueError('Unknown task:', task)
envs.append(env)
env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims
@@ -43,15 +59,16 @@ def make_env(cfg):
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
+
else:
env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
- except UnknownTaskError:
+ except ValueError:
pass
if env is None:
- raise UnknownTaskError(cfg.task)
+ raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py
index 32cb4b6..97be75a 100644
--- a/tdmpc2/envs/dmcontrol.py
+++ b/tdmpc2/envs/dmcontrol.py
@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
-from envs.exceptions import UnknownTaskError
import gym
@@ -187,7 +186,8 @@ def make_env(cfg):
domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', task)
+ assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
env = suite.load(domain,
task,
task_kwargs={'random': cfg.seed},
diff --git a/tdmpc2/envs/exceptions.py b/tdmpc2/envs/exceptions.py
deleted file mode 100644
index 9bf1390..0000000
--- a/tdmpc2/envs/exceptions.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-class UnknownTaskError(Exception):
- def __init__(self, task):
- super().__init__(f'Unknown task: {task}')
diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py
index 1d2e4c9..7b0b6ed 100644
--- a/tdmpc2/envs/maniskill.py
+++ b/tdmpc2/envs/maniskill.py
@@ -1,7 +1,6 @@
import gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
import mani_skill2.envs
@@ -65,7 +64,8 @@ def make_env(cfg):
Make ManiSkill2 environment.
"""
if cfg.task not in MANISKILL_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make(
task_cfg['env'],
diff --git a/tdmpc2/envs/metaworld.py b/tdmpc2/envs/metaworld.py
index fd7379d..f5f4f0d 100644
--- a/tdmpc2/envs/metaworld.py
+++ b/tdmpc2/envs/metaworld.py
@@ -1,7 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -44,7 +43,8 @@ def make_env(cfg):
"""
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py
index c503782..fa6876e 100644
--- a/tdmpc2/envs/myosuite.py
+++ b/tdmpc2/envs/myosuite.py
@@ -1,24 +1,19 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
MYOSUITE_TASKS = {
- 'myo-finger-reach': 'myoFingerReachFixed-v0',
- 'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
- 'myo-finger-pose': 'myoFingerPoseFixed-v0',
- 'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
- 'myo-hand-reach': 'myoHandReachFixed-v0',
- 'myo-hand-reach-hard': 'myoHandReachRandom-v0',
- 'myo-hand-pose': 'myoHandPoseFixed-v0',
- 'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
- 'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
- 'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
- 'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
- 'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
- 'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
- 'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
+ 'myo-reach': 'myoHandReachFixed-v0',
+ 'myo-reach-hard': 'myoHandReachRandom-v0',
+ 'myo-pose': 'myoHandPoseFixed-v0',
+ 'myo-pose-hard': 'myoHandPoseRandom-v0',
+ 'myo-obj-hold': 'myoHandObjHoldFixed-v0',
+ 'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
+ 'myo-key-turn': 'myoHandKeyTurnFixed-v0',
+ 'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
+ 'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
+ 'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
@@ -50,7 +45,8 @@ def make_env(cfg):
Make Myosuite environment.
"""
if not cfg.task in MYOSUITE_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)
diff --git a/tdmpc2/train.py b/tdmpc2/train.py
index a35c11b..5953bb2 100755
--- a/tdmpc2/train.py
+++ b/tdmpc2/train.py
@@ -1,5 +1,6 @@
import os
os.environ['MUJOCO_GL'] = 'egl'
+os.environ['LAZY_LEGACY_OP'] = '0'
import warnings
warnings.filterwarnings('ignore')
import torch
diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py
index 94835ca..ca33009 100755
--- a/tdmpc2/trainer/online_trainer.py
+++ b/tdmpc2/trainer/online_trainer.py
@@ -50,11 +50,11 @@ class OnlineTrainer(Trainer):
def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode."""
if isinstance(obs, dict):
- obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
+ obs = TensorDict(obs, batch_size=(), device='cpu')
else:
obs = obs.unsqueeze(0).cpu()
if action is None:
- action = torch.empty_like(self.env.rand_act())
+ action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(dict(