From c694d286f04ca4133cc160227a5575913a9177d1 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 10 Nov 2024 12:25:43 -0800 Subject: [PATCH] add assertion for compile=true compatibility --- docker/environment.yaml | 9 +++++---- tdmpc2/common/parser.py | 5 +++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docker/environment.yaml b/docker/environment.yaml index 9da54e0..87cab71 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -14,6 +14,7 @@ dependencies: - torchvision - pip: - absl-py==2.1.0 + - "cython<3" - dm-control==1.0.8 - glfw==2.7.0 - ffmpeg==1.4 @@ -23,6 +24,8 @@ dependencies: - hydra-core==1.3.2 - hydra-submitit-launcher==1.2.0 - submitit==1.5.1 + - setuptools==65.5.0 + - patchelf==0.17.2.1 - omegaconf==2.3.0 - moviepy==1.0.3 - mujoco==2.3.1 @@ -34,13 +37,11 @@ dependencies: - tqdm==4.66.4 - pandas==2.0.3 - wandb==0.17.4 - - matplotlib==3.7.5 - - seaborn==0.13.2 - - gpustat==1.1.1 + - wheel==0.38.0 #################### # Gym: # (unmaintained but required for maniskill2/meta-world) - - gym==0.21.0 + # - gym==0.21.0 #################### # ManiSkill2: # (requires gym==0.21.0 which occasionally breaks) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index a8d9f25..e162eac 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -77,4 +77,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) + # Check torch.compile compatibility + if cfg.get('compile', False): + assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' + assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.' + return cfg_to_dataclass(cfg)