update readme + clean up

Nicklas Hansen
2025-04-15 16:32:15 -07:00
parent eece80123d
commit 7942e9082b
3 changed files with 8 additions and 11 deletions


@@ -12,9 +12,9 @@ Official implementation of
----
-**Announcement: training just got ~4.5x faster!**
+**Announcement (Apr 2025): support for episodic tasks!**
-Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install recent `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
+We have added support for episodic RL (tasks with terminations) in the latest release. This functionality can be enabled with `episodic=true` but remains disabled by default to ensure reproducibility of results across releases.
----
@@ -74,7 +74,7 @@ See `docker/Dockerfile` for installation instructions if you do not already have
## Supported tasks
-This codebase currently supports **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite**. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
+This codebase provides support for all **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite** used in our paper. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
| domain | task
| --- | --- |
@@ -87,9 +87,9 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| myosuite | myo-key-turn
| myosuite | myo-key-turn-hard
-which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
+which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. We also provide basic support for other MuJoCo/Box2d Gymnasium tasks; refer to the `envs` directory for a list of tasks. It should be relatively straightforward to add support for custom tasks by following the examples in `envs`.
-**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
+**Note:** we also provide support for image observations in the DMControl tasks. Use argument `obs=rgb` if you wish to train visual policies.
## Example usage
@@ -121,8 +121,6 @@ $ python train.py task=walker-walk obs=rgb
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
-**As of Jan 7, 2024 the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results, but they appear to be similar based on our limited testing.
----
## Citation


@@ -71,6 +71,7 @@ enable_wandb: true
save_csv: true
# misc
+compile: true
save_video: true
save_agent: true
seed: 1
@@ -88,6 +89,3 @@ action_dims: ???
episode_lengths: ???
seed_steps: ???
bin_size: ???
-# speedups
-compile: false


@@ -131,7 +131,8 @@ class TDMPC2(torch.nn.Module):
G = G + discount * (1-termination) * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
-termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
+if self.cfg.episodic:
+    termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
action, _ = self.model.pi(z, task)
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
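
Note on the hunk above: the effect of gating the termination update behind `episodic` can be illustrated with a plain-Python sketch. This is not the repository's code. The function name `estimate_value` and the arguments `rewards`, `term_probs`, and `terminal_q` are hypothetical stand-ins for the model's per-step reward prediction, termination prediction, and final Q-value, and scalars replace the torch tensors used in the actual method.

```python
def estimate_value(rewards, term_probs, terminal_q, discount=0.99, episodic=True):
    """Discounted return of an imagined rollout with an optional termination mask.

    All arguments are hypothetical stand-ins for model predictions:
    rewards[t] and term_probs[t] for each planning step, terminal_q for the
    bootstrap value at the end of the horizon.
    """
    G = 0.0
    running_discount = 1.0
    termination = 0.0  # latches to 1.0 once a termination has been predicted
    for reward, prob in zip(rewards, term_probs):
        # Rewards after a predicted termination are masked out.
        G += running_discount * (1.0 - termination) * reward
        running_discount *= discount
        if episodic:
            # Mirrors the clipped update in the diff: threshold at 0.5, cap at 1.
            termination = min(termination + float(prob > 0.5), 1.0)
    # The bootstrap value is masked the same way as the rewards.
    return G + running_discount * (1.0 - termination) * terminal_q


rewards = [1.0, 1.0, 1.0]
term_probs = [0.1, 0.9, 0.1]  # termination predicted at the second step
masked = estimate_value(rewards, term_probs, terminal_q=10.0, discount=0.5, episodic=True)
unmasked = estimate_value(rewards, term_probs, terminal_q=10.0, discount=0.5, episodic=False)
```

With `episodic=False` the mask stays at zero, so the estimate reduces to the ordinary discounted sum plus the bootstrap term. This is consistent with the README's statement that the feature is disabled by default to keep results reproducible across releases.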