diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index e931cd5..90196a3 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -78,7 +78,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) # Check action space compatibility - if cfg.get('action', 'continuous') == 'discrete': + assert cfg.action in ['continuous', 'discrete'], \ + f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]' + if cfg.action == 'discrete': assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.' # Check torch.compile compatibility diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 3b3f91c..7da75d3 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -74,7 +74,7 @@ def make_env(cfg): env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) - if cfg.get('action', 'discrete'): + if cfg.get('action', 'continuous') == 'discrete': env = DiscreteWrapper(env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}