allow missing env dependencies + update readme

This commit is contained in:
Nicklas Hansen
2023-12-28 07:33:03 -08:00
parent 54145a4d8c
commit 6cb779aa3a
10 changed files with 95 additions and 40 deletions

View File

@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/> <img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL. This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
---- ----
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t <user>/tdmpc2:0.1.0 cd docker && docker build . -t <user>/tdmpc2:0.1.0
``` ```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
``` ```
conda env create -f docker/environment.yaml conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml
``` ```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
``` ```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall | metaworld | mw-pick-place-wall
| maniskill | pick-cube | maniskill | pick-cube
| maniskill | pick-ycb | maniskill | pick-ycb
| myosuite | myo-hand-key-turn | myosuite | myo-key-turn
| myosuite | myo-hand-key-turn-hard | myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
## Example usage ## Example usage
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024 $ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024 $ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000 $ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb
``` ```
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.

View File

@@ -26,6 +26,7 @@ dependencies:
- hydra-core - hydra-core
- hydra-submitit-launcher - hydra-submitit-launcher
- submitit - submitit
- pandas
- patchelf - patchelf
- protobuf - protobuf
- tqdm - tqdm

View File

@@ -0,0 +1,39 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

View File

@@ -6,11 +6,27 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
try:
from envs.dmcontrol import make_env as make_dm_control_env from envs.dmcontrol import make_env as make_dm_control_env
# from envs.maniskill import make_env as make_maniskill_env except:
# from envs.metaworld import make_env as make_metaworld_env make_dm_control_env = missing_dependencies
# from envs.myosuite import make_env as make_myosuite_env try:
from envs.exceptions import UnknownTaskError from envs.maniskill import make_env as make_maniskill_env
except:
make_maniskill_env = missing_dependencies
try:
from envs.metaworld import make_env as make_metaworld_env
except:
make_metaworld_env = missing_dependencies
try:
from envs.myosuite import make_env as make_myosuite_env
except:
make_myosuite_env = missing_dependencies
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
_cfg.multitask = False _cfg.multitask = False
env = make_env(_cfg) env = make_env(_cfg)
if env is None: if env is None:
raise UnknownTaskError(task) raise ValueError('Unknown task:', task)
envs.append(env) envs.append(env)
env = MultitaskWrapper(cfg, envs) env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims cfg.obs_shapes = env._obs_dims
@@ -43,15 +59,16 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
except UnknownTaskError: except ValueError:
pass pass
if env is None: if env is None:
raise UnknownTaskError(cfg.task) raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env) env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)

View File

@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs from dm_env import StepType, specs
from envs.exceptions import UnknownTaskError
import gym import gym
@@ -187,7 +186,8 @@ def make_env(cfg):
domain, task = cfg.task.replace('-', '_').split('_', 1) domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain) domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS: if (domain, task) not in suite.ALL_TASKS:
raise UnknownTaskError(cfg.task) raise ValueError('Unknown task:', task)
assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
env = suite.load(domain, env = suite.load(domain,
task, task,
task_kwargs={'random': cfg.seed}, task_kwargs={'random': cfg.seed},

View File

@@ -1,4 +0,0 @@
class UnknownTaskError(Exception):
def __init__(self, task):
super().__init__(f'Unknown task: {task}')

View File

@@ -1,7 +1,6 @@
import gym import gym
import numpy as np import numpy as np
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
import mani_skill2.envs import mani_skill2.envs
@@ -65,7 +64,8 @@ def make_env(cfg):
Make ManiSkill2 environment. Make ManiSkill2 environment.
""" """
if cfg.task not in MANISKILL_TASKS: if cfg.task not in MANISKILL_TASKS:
raise UnknownTaskError(cfg.task) raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
task_cfg = MANISKILL_TASKS[cfg.task] task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make( env = gym.make(
task_cfg['env'], task_cfg['env'],

View File

@@ -1,7 +1,6 @@
import numpy as np import numpy as np
import gym import gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -44,7 +43,8 @@ def make_env(cfg):
""" """
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable" env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE: if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
raise UnknownTaskError(cfg.task) raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg) env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = TimeLimit(env, max_episode_steps=100)

View File

@@ -1,24 +1,19 @@
import numpy as np import numpy as np
import gym import gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
MYOSUITE_TASKS = { MYOSUITE_TASKS = {
'myo-finger-reach': 'myoFingerReachFixed-v0', 'myo-reach': 'myoHandReachFixed-v0',
'myo-finger-reach-hard': 'myoFingerReachRandom-v0', 'myo-reach-hard': 'myoHandReachRandom-v0',
'myo-finger-pose': 'myoFingerPoseFixed-v0', 'myo-pose': 'myoHandPoseFixed-v0',
'myo-finger-pose-hard': 'myoFingerPoseRandom-v0', 'myo-pose-hard': 'myoHandPoseRandom-v0',
'myo-hand-reach': 'myoHandReachFixed-v0', 'myo-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-hand-reach-hard': 'myoHandReachRandom-v0', 'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-hand-pose': 'myoHandPoseFixed-v0', 'myo-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-hand-pose-hard': 'myoHandPoseRandom-v0', 'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0', 'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0', 'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
} }
@@ -50,7 +45,8 @@ def make_env(cfg):
Make Myosuite environment. Make Myosuite environment.
""" """
if not cfg.task in MYOSUITE_TASKS: if not cfg.task in MYOSUITE_TASKS:
raise UnknownTaskError(cfg.task) raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task]) env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg) env = MyoSuiteWrapper(env, cfg)

View File

@@ -54,7 +54,7 @@ class OnlineTrainer(Trainer):
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action = torch.empty_like(self.env.rand_act()) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
td = TensorDict(dict( td = TensorDict(dict(