From dc6720d3223e3c70b2a880db41e7d7f847e45608 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 11 Nov 2024 18:20:09 -0800 Subject: [PATCH] fix --- tdmpc2/common/parser.py | 4 +++- tdmpc2/envs/__init__.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index e931cd5..90196a3 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -78,7 +78,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) # Check action space compatibility - if cfg.get('action', 'continuous') == 'discrete': + assert cfg.action in ['continuous', 'discrete'], \ + f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]' + if cfg.action == 'discrete': assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.' # Check torch.compile compatibility diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 3b3f91c..7da75d3 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -74,7 +74,7 @@ def make_env(cfg): env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) - if cfg.get('action', 'discrete'): + if cfg.get('action', 'continuous') == 'discrete': env = DiscreteWrapper(env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}