enable torch.compile for offline rl + rgb inputs

This commit is contained in:
Nicklas Hansen
2024-12-25 12:22:39 -08:00
parent e452ca7539
commit a19f91c0b5
2 changed files with 1 additions and 6 deletions

View File

@@ -14,7 +14,7 @@ Official implementation of
**Announcement: training just got ~4.5x faster!**
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install recent `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
----

View File

@@ -77,9 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Check torch.compile compatibility
if cfg.get('compile', False):
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.'
return cfg_to_dataclass(cfg)