From 7942e9082b792eca540e0274bc572c7bc066f36b Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 15 Apr 2025 16:32:15 -0700 Subject: [PATCH] update readme + clean up --- README.md | 12 +++++------- tdmpc2/config.yaml | 4 +--- tdmpc2/tdmpc2.py | 3 ++- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 0e595f1..9a3fd3f 100755 --- a/README.md +++ b/README.md @@ -12,9 +12,9 @@ Official implementation of ---- -**Announcement: training just got ~4.5x faster!** +**Announcement (Apr 2025): support for episodic tasks!** -Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install recent `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility! +We have added support for episodic RL (tasks with terminations) in the latest release. This functionality can be enabled with `episodic=true` but remains disabled by default to ensure reproducibility of results across releases. ---- @@ -74,7 +74,7 @@ See `docker/Dockerfile` for installation instructions if you do not already have ## Supported tasks -This codebase currently supports **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite**. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain: +This codebase provides support for all **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite** used in our paper. 
Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite. See the table below for the expected name formatting of each task domain: | domain | task | --- | --- | @@ -87,9 +87,9 @@ This codebase currently supports **104** continuous control tasks from **DMContr | myosuite | myo-key-turn | myosuite | myo-key-turn-hard -which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. +which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. We also provide basic support for other MuJoCo/Box2D Gymnasium tasks; refer to the `envs` directory for a list of tasks. It should be relatively straightforward to add support for custom tasks by following the examples in `envs`. -**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies. +**Note:** we also provide support for image observations in the DMControl tasks. Use the argument `obs=rgb` if you wish to train visual policies. ## Example usage @@ -121,8 +121,6 @@ $ python train.py task=walker-walk obs=rgb We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. 
-**As of Jan 7, 2024 the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results, but they appear to be similar based on our limited testing. - ---- ## Citation diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index ac6ca5d..ff15dbb 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -71,6 +71,7 @@ enable_wandb: true save_csv: true # misc +compile: true save_video: true save_agent: true seed: 1 @@ -88,6 +89,3 @@ action_dims: ??? episode_lengths: ??? seed_steps: ??? bin_size: ??? - -# speedups -compile: false diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 1df1b77..7357d48 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -131,7 +131,8 @@ class TDMPC2(torch.nn.Module): G = G + discount * (1-termination) * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update - termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.) + if self.cfg.episodic: + termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.) action, _ = self.model.pi(z, task) return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')