update readme + clean up
This commit is contained in:
12
README.md
12
README.md
@@ -12,9 +12,9 @@ Official implementation of
|
||||
|
||||
----
|
||||
|
||||
**Announcement: training just got ~4.5x faster!**
|
||||
**Announcement (Apr 2025): support for episodic tasks!**
|
||||
|
||||
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install recent `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
|
||||
We have added support for episodic RL (tasks with terminations) in the latest release. This functionality can be enabled with `episodic=true` but remains disabled by default to ensure reproducibility of results across releases.
|
||||
|
||||
----
|
||||
|
||||
@@ -74,7 +74,7 @@ See `docker/Dockerfile` for installation instructions if you do not already have
|
||||
|
||||
## Supported tasks
|
||||
|
||||
This codebase currently supports **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite**. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
|
||||
This codebase provides support for all **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite** used in our paper. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
|
||||
|
||||
| domain | task
|
||||
| --- | --- |
|
||||
@@ -87,9 +87,9 @@ This codebase currently supports **104** continuous control tasks from **DMContr
|
||||
| myosuite | myo-key-turn
|
||||
| myosuite | myo-key-turn-hard
|
||||
|
||||
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
|
||||
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. We also provide basic support for other MuJoCo/Box2d Gymnasium tasks; refer to the `envs` directory for a list of tasks. It should be relatively straightforward to add support for custom tasks by following the examples in `envs`.
|
||||
|
||||
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
|
||||
**Note:** we also provide support for image observations in the DMControl tasks. Use argument `obs=rgb` if you wish to train visual policies.
|
||||
|
||||
|
||||
## Example usage
|
||||
@@ -121,8 +121,6 @@ $ python train.py task=walker-walk obs=rgb
|
||||
|
||||
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
|
||||
|
||||
**As of Jan 7, 2024 the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results, but they appear to be similar based on our limited testing.
|
||||
|
||||
----
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -71,6 +71,7 @@ enable_wandb: true
|
||||
save_csv: true
|
||||
|
||||
# misc
|
||||
compile: true
|
||||
save_video: true
|
||||
save_agent: true
|
||||
seed: 1
|
||||
@@ -88,6 +89,3 @@ action_dims: ???
|
||||
episode_lengths: ???
|
||||
seed_steps: ???
|
||||
bin_size: ???
|
||||
|
||||
# speedups
|
||||
compile: false
|
||||
|
||||
@@ -131,7 +131,8 @@ class TDMPC2(torch.nn.Module):
|
||||
G = G + discount * (1-termination) * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
|
||||
if self.cfg.episodic:
|
||||
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
|
||||
action, _ = self.model.pi(z, task)
|
||||
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user