allow missing env dependencies + update readme
This commit is contained in:
14
README.md
14
README.md
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
|
||||
|
||||
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
|
||||
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
|
||||
----
|
||||
|
||||
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
|
||||
cd docker && docker build . -t <user>/tdmpc2:0.1.0
|
||||
```
|
||||
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
|
||||
|
||||
```
|
||||
conda env create -f docker/environment.yaml
|
||||
conda env create -f docker/environment_minimal.yaml
|
||||
```
|
||||
|
||||
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
|
||||
|
||||
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
|
||||
|
||||
```
|
||||
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
|
||||
| metaworld | mw-pick-place-wall
|
||||
| maniskill | pick-cube
|
||||
| maniskill | pick-ycb
|
||||
| myosuite | myo-hand-key-turn
|
||||
| myosuite | myo-hand-key-turn-hard
|
||||
| myosuite | myo-key-turn
|
||||
| myosuite | myo-key-turn-hard
|
||||
|
||||
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
|
||||
|
||||
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
|
||||
|
||||
|
||||
## Example usage
|
||||
|
||||
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
|
||||
$ python train.py task=mt80 model_size=48 batch_size=1024
|
||||
$ python train.py task=mt30 model_size=317 batch_size=1024
|
||||
$ python train.py task=dog-run steps=7000000
|
||||
$ python train.py task=walker-walk obs=rgb
|
||||
```
|
||||
|
||||
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
|
||||
|
||||
@@ -26,6 +26,7 @@ dependencies:
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
|
||||
39
docker/environment_minimal.yaml
Normal file
39
docker/environment_minimal.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
name: tdmpc2
|
||||
channels:
|
||||
- pytorch-nightly
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.7
|
||||
- glew
|
||||
- glib
|
||||
- pip==21
|
||||
- pip:
|
||||
- absl-py
|
||||
- glfw
|
||||
- kornia
|
||||
- termcolor
|
||||
- gym==0.21.0
|
||||
- moviepy
|
||||
- ffmpeg
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
||||
- omegaconf
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
- setuptools==65.5.0
|
||||
- "cython<3"
|
||||
- dm-control
|
||||
- pillow
|
||||
- tensordict-nightly
|
||||
- torchrl-nightly
|
||||
- wandb
|
||||
@@ -6,11 +6,27 @@ import gym
|
||||
from envs.wrappers.multitask import MultitaskWrapper
|
||||
from envs.wrappers.pixels import PixelWrapper
|
||||
from envs.wrappers.tensor import TensorWrapper
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
# from envs.maniskill import make_env as make_maniskill_env
|
||||
# from envs.metaworld import make_env as make_metaworld_env
|
||||
# from envs.myosuite import make_env as make_myosuite_env
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
def missing_dependencies(task):
|
||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||
|
||||
try:
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
except:
|
||||
make_dm_control_env = missing_dependencies
|
||||
try:
|
||||
from envs.maniskill import make_env as make_maniskill_env
|
||||
except:
|
||||
make_maniskill_env = missing_dependencies
|
||||
try:
|
||||
from envs.metaworld import make_env as make_metaworld_env
|
||||
except:
|
||||
make_metaworld_env = missing_dependencies
|
||||
try:
|
||||
from envs.myosuite import make_env as make_myosuite_env
|
||||
except:
|
||||
make_myosuite_env = missing_dependencies
|
||||
|
||||
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
|
||||
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
|
||||
_cfg.multitask = False
|
||||
env = make_env(_cfg)
|
||||
if env is None:
|
||||
raise UnknownTaskError(task)
|
||||
raise ValueError('Unknown task:', task)
|
||||
envs.append(env)
|
||||
env = MultitaskWrapper(cfg, envs)
|
||||
cfg.obs_shapes = env._obs_dims
|
||||
@@ -43,15 +59,16 @@ def make_env(cfg):
|
||||
gym.logger.set_level(40)
|
||||
if cfg.multitask:
|
||||
env = make_multitask_env(cfg)
|
||||
|
||||
else:
|
||||
env = None
|
||||
for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||
try:
|
||||
env = fn(cfg)
|
||||
except UnknownTaskError:
|
||||
except ValueError:
|
||||
pass
|
||||
if env is None:
|
||||
raise UnknownTaskError(cfg.task)
|
||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||
env = TensorWrapper(env)
|
||||
if cfg.get('obs', 'state') == 'rgb':
|
||||
env = PixelWrapper(cfg, env)
|
||||
|
||||
@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||
from dm_control.suite.wrappers import action_scale
|
||||
from dm_env import StepType, specs
|
||||
from envs.exceptions import UnknownTaskError
|
||||
import gym
|
||||
|
||||
|
||||
@@ -187,7 +186,8 @@ def make_env(cfg):
|
||||
domain, task = cfg.task.replace('-', '_').split('_', 1)
|
||||
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
|
||||
if (domain, task) not in suite.ALL_TASKS:
|
||||
raise UnknownTaskError(cfg.task)
|
||||
raise ValueError('Unknown task:', task)
|
||||
assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
|
||||
env = suite.load(domain,
|
||||
task,
|
||||
task_kwargs={'random': cfg.seed},
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
|
||||
class UnknownTaskError(Exception):
|
||||
def __init__(self, task):
|
||||
super().__init__(f'Unknown task: {task}')
|
||||
@@ -1,7 +1,6 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
import mani_skill2.envs
|
||||
|
||||
@@ -65,7 +64,8 @@ def make_env(cfg):
|
||||
Make ManiSkill2 environment.
|
||||
"""
|
||||
if cfg.task not in MANISKILL_TASKS:
|
||||
raise UnknownTaskError(cfg.task)
|
||||
raise ValueError('Unknown task:', cfg.task)
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
task_cfg = MANISKILL_TASKS[cfg.task]
|
||||
env = gym.make(
|
||||
task_cfg['env'],
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
||||
|
||||
@@ -44,7 +43,8 @@ def make_env(cfg):
|
||||
"""
|
||||
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
|
||||
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
|
||||
raise UnknownTaskError(cfg.task)
|
||||
raise ValueError('Unknown task:', cfg.task)
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
||||
env = MetaWorldWrapper(env, cfg)
|
||||
env = TimeLimit(env, max_episode_steps=100)
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
import numpy as np
|
||||
import gym
|
||||
from envs.wrappers.time_limit import TimeLimit
|
||||
from envs.exceptions import UnknownTaskError
|
||||
|
||||
|
||||
MYOSUITE_TASKS = {
|
||||
'myo-finger-reach': 'myoFingerReachFixed-v0',
|
||||
'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
|
||||
'myo-finger-pose': 'myoFingerPoseFixed-v0',
|
||||
'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
|
||||
'myo-hand-reach': 'myoHandReachFixed-v0',
|
||||
'myo-hand-reach-hard': 'myoHandReachRandom-v0',
|
||||
'myo-hand-pose': 'myoHandPoseFixed-v0',
|
||||
'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
|
||||
'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
|
||||
'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
|
||||
'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
|
||||
'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
|
||||
'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
|
||||
'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
|
||||
'myo-reach': 'myoHandReachFixed-v0',
|
||||
'myo-reach-hard': 'myoHandReachRandom-v0',
|
||||
'myo-pose': 'myoHandPoseFixed-v0',
|
||||
'myo-pose-hard': 'myoHandPoseRandom-v0',
|
||||
'myo-obj-hold': 'myoHandObjHoldFixed-v0',
|
||||
'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
|
||||
'myo-key-turn': 'myoHandKeyTurnFixed-v0',
|
||||
'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
|
||||
'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
|
||||
'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
|
||||
}
|
||||
|
||||
|
||||
@@ -50,7 +45,8 @@ def make_env(cfg):
|
||||
Make Myosuite environment.
|
||||
"""
|
||||
if not cfg.task in MYOSUITE_TASKS:
|
||||
raise UnknownTaskError(cfg.task)
|
||||
raise ValueError('Unknown task:', cfg.task)
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
import myosuite
|
||||
env = gym.make(MYOSUITE_TASKS[cfg.task])
|
||||
env = MyoSuiteWrapper(env, cfg)
|
||||
|
||||
@@ -54,7 +54,7 @@ class OnlineTrainer(Trainer):
|
||||
else:
|
||||
obs = obs.unsqueeze(0).cpu()
|
||||
if action is None:
|
||||
action = torch.empty_like(self.env.rand_act())
|
||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||
if reward is None:
|
||||
reward = torch.tensor(float('nan'))
|
||||
td = TensorDict(dict(
|
||||
|
||||
Reference in New Issue
Block a user