allow missing env dependencies + update readme

This commit is contained in:
Nicklas Hansen
2023-12-28 07:33:03 -08:00
parent 54145a4d8c
commit 6cb779aa3a
10 changed files with 95 additions and 40 deletions

View File

@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
----
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t <user>/tdmpc2:0.1.0
```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
If you prefer to install dependencies manually, start by creating the `conda` environment with one of the following commands:
```
conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml
```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall
| maniskill | pick-cube
| maniskill | pick-ycb
| myosuite | myo-hand-key-turn
| myosuite | myo-hand-key-turn-hard
| myosuite | myo-key-turn
| myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation are specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
**As of Dec 27, 2023, the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use the argument `obs=rgb` if you wish to train visual policies.
## Example usage
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb
```
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.

View File

@@ -26,6 +26,7 @@ dependencies:
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm

View File

@@ -0,0 +1,39 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

View File

@@ -6,11 +6,27 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task):
    """Stand-in environment factory used when a domain's dependencies are missing.

    This function is bound in place of a domain's ``make_env`` when importing
    that domain fails, so it is invoked the same way the real factories are.

    Args:
        task: Either a task name, or a config object exposing a ``task``
            attribute (the factories are called with the full config).

    Raises:
        ValueError: Always; the message names the task that was requested.
    """
    # Factories are called as fn(cfg), so `task` is usually the whole config;
    # report only its task name instead of dumping the entire config object.
    task_name = getattr(task, 'task', task)
    raise ValueError(f'Missing dependencies for task {task_name}; install dependencies to use this environment.')
try:
from envs.dmcontrol import make_env as make_dm_control_env
# from envs.maniskill import make_env as make_maniskill_env
# from envs.metaworld import make_env as make_metaworld_env
# from envs.myosuite import make_env as make_myosuite_env
from envs.exceptions import UnknownTaskError
except:
make_dm_control_env = missing_dependencies
try:
from envs.maniskill import make_env as make_maniskill_env
except:
make_maniskill_env = missing_dependencies
try:
from envs.metaworld import make_env as make_metaworld_env
except:
make_metaworld_env = missing_dependencies
try:
from envs.myosuite import make_env as make_myosuite_env
except:
make_myosuite_env = missing_dependencies
warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
_cfg.multitask = False
env = make_env(_cfg)
if env is None:
raise UnknownTaskError(task)
raise ValueError('Unknown task:', task)
envs.append(env)
env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims
@@ -43,15 +59,16 @@ def make_env(cfg):
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
else:
env = None
for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
except UnknownTaskError:
except ValueError:
pass
if env is None:
raise UnknownTaskError(cfg.task)
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)

View File

@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
from envs.exceptions import UnknownTaskError
import gym
@@ -187,7 +186,8 @@ def make_env(cfg):
domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', task)
assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
env = suite.load(domain,
task,
task_kwargs={'random': cfg.seed},

View File

@@ -1,4 +0,0 @@
class UnknownTaskError(Exception):
    """Raised when a requested task name is not recognized.

    Attributes:
        task: The task name that was passed to the constructor.
    """

    def __init__(self, task):
        """Build the error with the message ``'Unknown task: <task>'``.

        Args:
            task: Name of the task that could not be resolved.
        """
        super().__init__(f'Unknown task: {task}')
        # Keep the offending name so handlers can inspect it without parsing the message.
        self.task = task

View File

@@ -1,7 +1,6 @@
import gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
import mani_skill2.envs
@@ -65,7 +64,8 @@ def make_env(cfg):
Make ManiSkill2 environment.
"""
if cfg.task not in MANISKILL_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make(
task_cfg['env'],

View File

@@ -1,7 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -44,7 +43,8 @@ def make_env(cfg):
"""
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)

View File

@@ -1,24 +1,19 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
# Maps the task names accepted via `cfg.task` to the gym environment IDs
# that `make_env` passes to `gym.make`. Names ending in `-hard` map to the
# `Random` variant of an environment; the others map to the `Fixed` variant.
MYOSUITE_TASKS = {
'myo-finger-reach': 'myoFingerReachFixed-v0',
'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
'myo-finger-pose': 'myoFingerPoseFixed-v0',
'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
'myo-hand-reach': 'myoHandReachFixed-v0',
'myo-hand-reach-hard': 'myoHandReachRandom-v0',
'myo-hand-pose': 'myoHandPoseFixed-v0',
'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
'myo-reach': 'myoHandReachFixed-v0',
'myo-reach-hard': 'myoHandReachRandom-v0',
'myo-pose': 'myoHandPoseFixed-v0',
'myo-pose-hard': 'myoHandPoseRandom-v0',
'myo-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
@@ -50,7 +45,8 @@ def make_env(cfg):
Make Myosuite environment.
"""
if not cfg.task in MYOSUITE_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)

View File

@@ -54,7 +54,7 @@ class OnlineTrainer(Trainer):
else:
obs = obs.unsqueeze(0).cpu()
if action is None:
action = torch.empty_like(self.env.rand_act())
action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(dict(