diff --git a/README.md b/README.md index a3ee4f1..94534a4 100755 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
-This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL. +This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL. ---- @@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image cd docker && docker build . -t /tdmpc2:0.1.0 ``` -If you prefer to install dependencies manually, start by installing dependencies via `conda` by running +If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands: ``` conda env create -f docker/environment.yaml +conda env create -f docker/environment_minimal.yaml ``` +The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks. + If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running ``` @@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr | metaworld | mw-pick-place-wall | maniskill | pick-cube | maniskill | pick-ycb -| myosuite | myo-hand-key-turn -| myosuite | myo-hand-key-turn-hard +| myosuite | myo-key-turn +| myosuite | myo-key-turn-hard which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. +**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies. + ## Example usage @@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and $ python train.py task=mt80 model_size=48 batch_size=1024 $ python train.py task=mt30 model_size=317 batch_size=1024 $ python train.py task=dog-run steps=7000000 +$ python train.py task=walker-walk obs=rgb ``` We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. diff --git a/docker/environment.yaml b/docker/environment.yaml index 18a9914..6792839 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -26,6 +26,7 @@ dependencies: - hydra-core - hydra-submitit-launcher - submitit + - pandas - patchelf - protobuf - tqdm diff --git a/docker/environment_minimal.yaml b/docker/environment_minimal.yaml new file mode 100644 index 0000000..fbe30f6 --- /dev/null +++ b/docker/environment_minimal.yaml @@ -0,0 +1,39 @@ +name: tdmpc2 +channels: + - pytorch-nightly + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.9.0 + - pytorch + - torchvision + - cudatoolkit=11.7 + - glew + - glib + - pip==21 + - pip: + - absl-py + - glfw + - kornia + - termcolor + - gym==0.21.0 + - moviepy + - ffmpeg + - imageio + - imageio-ffmpeg + - omegaconf + - hydra-core + - hydra-submitit-launcher + - submitit + - pandas + - patchelf + - protobuf + - tqdm + - setuptools==65.5.0 + - "cython<3" + - dm-control + - pillow + - tensordict-nightly + - torchrl-nightly + - wandb diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 5efcb73..6326a9e 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -6,11 +6,27 @@ import gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper -from envs.dmcontrol import make_env as make_dm_control_env -# from envs.maniskill import make_env as make_maniskill_env -# from envs.metaworld import make_env as make_metaworld_env -# from envs.myosuite import make_env as make_myosuite_env -from envs.exceptions import UnknownTaskError + +def missing_dependencies(task): + raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') + +try: + from envs.dmcontrol import make_env as make_dm_control_env +except: + make_dm_control_env = missing_dependencies +try: + from envs.maniskill import make_env as make_maniskill_env +except: + make_maniskill_env = missing_dependencies +try: + from envs.metaworld import make_env as make_metaworld_env +except: + make_metaworld_env = missing_dependencies +try: + from envs.myosuite import make_env as make_myosuite_env +except: + make_myosuite_env = missing_dependencies + warnings.filterwarnings('ignore', category=DeprecationWarning) @@ -27,7 +43,7 @@ def make_multitask_env(cfg): _cfg.multitask = False env = make_env(_cfg) if env is None: - raise UnknownTaskError(task) + raise ValueError('Unknown task:', task) envs.append(env) env = MultitaskWrapper(cfg, envs) cfg.obs_shapes = env._obs_dims @@ -43,15 +59,16 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) + else: env = None - for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]: + for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) - except UnknownTaskError: + except ValueError: pass if env is None: - raise UnknownTaskError(cfg.task) + raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 32cb4b6..97be75a 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) from dm_control.suite.wrappers import action_scale from dm_env import StepType, specs -from envs.exceptions import UnknownTaskError import gym @@ -187,7 +186,8 @@ def make_env(cfg): domain, task = cfg.task.replace('-', '_').split('_', 1) domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain) if (domain, task) not in suite.ALL_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', task) + assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.' env = suite.load(domain, task, task_kwargs={'random': cfg.seed}, diff --git a/tdmpc2/envs/exceptions.py b/tdmpc2/envs/exceptions.py deleted file mode 100644 index 9bf1390..0000000 --- a/tdmpc2/envs/exceptions.py +++ /dev/null @@ -1,4 +0,0 @@ - -class UnknownTaskError(Exception): - def __init__(self, task): - super().__init__(f'Unknown task: {task}') diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 1d2e4c9..7b0b6ed 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -1,7 +1,6 @@ import gym import numpy as np from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError import mani_skill2.envs @@ -65,7 +64,8 @@ def make_env(cfg): Make ManiSkill2 environment. """ if cfg.task not in MANISKILL_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' task_cfg = MANISKILL_TASKS[cfg.task] env = gym.make( task_cfg['env'], diff --git a/tdmpc2/envs/metaworld.py b/tdmpc2/envs/metaworld.py index fd7379d..f5f4f0d 100644 --- a/tdmpc2/envs/metaworld.py +++ b/tdmpc2/envs/metaworld.py @@ -1,7 +1,6 @@ import numpy as np import gym from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE @@ -44,7 +43,8 @@ def make_env(cfg): """ env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable" if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = MetaWorldWrapper(env, cfg) env = TimeLimit(env, max_episode_steps=100) diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py index c503782..fa6876e 100644 --- a/tdmpc2/envs/myosuite.py +++ b/tdmpc2/envs/myosuite.py @@ -1,24 +1,19 @@ import numpy as np import gym from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError MYOSUITE_TASKS = { - 'myo-finger-reach': 'myoFingerReachFixed-v0', - 'myo-finger-reach-hard': 'myoFingerReachRandom-v0', - 'myo-finger-pose': 'myoFingerPoseFixed-v0', - 'myo-finger-pose-hard': 'myoFingerPoseRandom-v0', - 'myo-hand-reach': 'myoHandReachFixed-v0', - 'myo-hand-reach-hard': 'myoHandReachRandom-v0', - 'myo-hand-pose': 'myoHandPoseFixed-v0', - 'myo-hand-pose-hard': 'myoHandPoseRandom-v0', - 'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0', - 'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0', - 'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0', - 'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0', - 'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0', - 'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0', + 'myo-reach': 'myoHandReachFixed-v0', + 'myo-reach-hard': 'myoHandReachRandom-v0', + 'myo-pose': 'myoHandPoseFixed-v0', + 'myo-pose-hard': 'myoHandPoseRandom-v0', + 'myo-obj-hold': 'myoHandObjHoldFixed-v0', + 'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0', + 'myo-key-turn': 'myoHandKeyTurnFixed-v0', + 'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0', + 'myo-pen-twirl': 'myoHandPenTwirlFixed-v0', + 'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0', } @@ -50,7 +45,8 @@ def make_env(cfg): Make Myosuite environment. """ if not cfg.task in MYOSUITE_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' import myosuite env = gym.make(MYOSUITE_TASKS[cfg.task]) env = MyoSuiteWrapper(env, cfg) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index f5f65cc..ca33009 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -54,7 +54,7 @@ class OnlineTrainer(Trainer): else: obs = obs.unsqueeze(0).cpu() if action is None: - action = torch.empty_like(self.env.rand_act()) + action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: reward = torch.tensor(float('nan')) td = TensorDict(dict(