diff --git a/README.md b/README.md
index a3ee4f1..94534a4 100755
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.

-This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
+This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
----
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t /tdmpc2:0.1.0
```
-If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
+If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
```
conda env create -f docker/environment.yaml
+conda env create -f docker/environment_minimal.yaml
```
+The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
+
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall
| maniskill | pick-cube
| maniskill | pick-ycb
-| myosuite | myo-hand-key-turn
-| myosuite | myo-hand-key-turn-hard
+| myosuite | myo-key-turn
+| myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
+**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
+
## Example usage
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
+$ python train.py task=walker-walk obs=rgb
```
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
diff --git a/docker/environment.yaml b/docker/environment.yaml
index 18a9914..6792839 100644
--- a/docker/environment.yaml
+++ b/docker/environment.yaml
@@ -26,6 +26,7 @@ dependencies:
- hydra-core
- hydra-submitit-launcher
- submitit
+ - pandas
- patchelf
- protobuf
- tqdm
diff --git a/docker/environment_minimal.yaml b/docker/environment_minimal.yaml
new file mode 100644
index 0000000..fbe30f6
--- /dev/null
+++ b/docker/environment_minimal.yaml
@@ -0,0 +1,39 @@
+name: tdmpc2
+channels:
+ - pytorch-nightly
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.9.0
+ - pytorch
+ - torchvision
+ - cudatoolkit=11.7
+ - glew
+ - glib
+ - pip==21
+ - pip:
+ - absl-py
+ - glfw
+ - kornia
+ - termcolor
+ - gym==0.21.0
+ - moviepy
+ - ffmpeg
+ - imageio
+ - imageio-ffmpeg
+ - omegaconf
+ - hydra-core
+ - hydra-submitit-launcher
+ - submitit
+ - pandas
+ - patchelf
+ - protobuf
+ - tqdm
+ - setuptools==65.5.0
+ - "cython<3"
+ - dm-control
+ - pillow
+ - tensordict-nightly
+ - torchrl-nightly
+ - wandb
diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py
index 5efcb73..6326a9e 100644
--- a/tdmpc2/envs/__init__.py
+++ b/tdmpc2/envs/__init__.py
@@ -6,11 +6,27 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
-from envs.dmcontrol import make_env as make_dm_control_env
-# from envs.maniskill import make_env as make_maniskill_env
-# from envs.metaworld import make_env as make_metaworld_env
-# from envs.myosuite import make_env as make_myosuite_env
-from envs.exceptions import UnknownTaskError
+
+def missing_dependencies(task):
+ raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
+
+try:
+ from envs.dmcontrol import make_env as make_dm_control_env
+except:
+ make_dm_control_env = missing_dependencies
+try:
+ from envs.maniskill import make_env as make_maniskill_env
+except:
+ make_maniskill_env = missing_dependencies
+try:
+ from envs.metaworld import make_env as make_metaworld_env
+except:
+ make_metaworld_env = missing_dependencies
+try:
+ from envs.myosuite import make_env as make_myosuite_env
+except:
+ make_myosuite_env = missing_dependencies
+
warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
_cfg.multitask = False
env = make_env(_cfg)
if env is None:
- raise UnknownTaskError(task)
+ raise ValueError('Unknown task:', task)
envs.append(env)
env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims
@@ -43,15 +59,16 @@ def make_env(cfg):
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
+
else:
env = None
- for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
+ for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
- except UnknownTaskError:
+ except ValueError:
pass
if env is None:
- raise UnknownTaskError(cfg.task)
+ raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py
index 32cb4b6..97be75a 100644
--- a/tdmpc2/envs/dmcontrol.py
+++ b/tdmpc2/envs/dmcontrol.py
@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
-from envs.exceptions import UnknownTaskError
import gym
@@ -187,7 +186,8 @@ def make_env(cfg):
domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', task)
+ assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
env = suite.load(domain,
task,
task_kwargs={'random': cfg.seed},
diff --git a/tdmpc2/envs/exceptions.py b/tdmpc2/envs/exceptions.py
deleted file mode 100644
index 9bf1390..0000000
--- a/tdmpc2/envs/exceptions.py
+++ /dev/null
@@ -1,4 +0,0 @@
-
-class UnknownTaskError(Exception):
- def __init__(self, task):
- super().__init__(f'Unknown task: {task}')
diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py
index 1d2e4c9..7b0b6ed 100644
--- a/tdmpc2/envs/maniskill.py
+++ b/tdmpc2/envs/maniskill.py
@@ -1,7 +1,6 @@
import gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
import mani_skill2.envs
@@ -65,7 +64,8 @@ def make_env(cfg):
Make ManiSkill2 environment.
"""
if cfg.task not in MANISKILL_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make(
task_cfg['env'],
diff --git a/tdmpc2/envs/metaworld.py b/tdmpc2/envs/metaworld.py
index fd7379d..f5f4f0d 100644
--- a/tdmpc2/envs/metaworld.py
+++ b/tdmpc2/envs/metaworld.py
@@ -1,7 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -44,7 +43,8 @@ def make_env(cfg):
"""
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)
diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py
index c503782..fa6876e 100644
--- a/tdmpc2/envs/myosuite.py
+++ b/tdmpc2/envs/myosuite.py
@@ -1,24 +1,19 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
-from envs.exceptions import UnknownTaskError
MYOSUITE_TASKS = {
- 'myo-finger-reach': 'myoFingerReachFixed-v0',
- 'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
- 'myo-finger-pose': 'myoFingerPoseFixed-v0',
- 'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
- 'myo-hand-reach': 'myoHandReachFixed-v0',
- 'myo-hand-reach-hard': 'myoHandReachRandom-v0',
- 'myo-hand-pose': 'myoHandPoseFixed-v0',
- 'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
- 'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
- 'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
- 'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
- 'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
- 'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
- 'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
+ 'myo-reach': 'myoHandReachFixed-v0',
+ 'myo-reach-hard': 'myoHandReachRandom-v0',
+ 'myo-pose': 'myoHandPoseFixed-v0',
+ 'myo-pose-hard': 'myoHandPoseRandom-v0',
+ 'myo-obj-hold': 'myoHandObjHoldFixed-v0',
+ 'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
+ 'myo-key-turn': 'myoHandKeyTurnFixed-v0',
+ 'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
+ 'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
+ 'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
@@ -50,7 +45,8 @@ def make_env(cfg):
Make Myosuite environment.
"""
if not cfg.task in MYOSUITE_TASKS:
- raise UnknownTaskError(cfg.task)
+ raise ValueError('Unknown task:', cfg.task)
+ assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)
diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py
index f5f65cc..ca33009 100755
--- a/tdmpc2/trainer/online_trainer.py
+++ b/tdmpc2/trainer/online_trainer.py
@@ -54,7 +54,7 @@ class OnlineTrainer(Trainer):
else:
obs = obs.unsqueeze(0).cpu()
if action is None:
- action = torch.empty_like(self.env.rand_act())
+ action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(dict(