This commit is contained in:
Nicklas Hansen
2024-11-11 18:20:09 -08:00
parent dee034070e
commit dc6720d322
2 changed files with 4 additions and 2 deletions

View File

@@ -78,7 +78,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Check action space compatibility
if cfg.get('action', 'continuous') == 'discrete':
assert cfg.action in ['continuous', 'discrete'], \
f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]'
if cfg.action == 'discrete':
assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.'
# Check torch.compile compatibility

View File

@@ -74,7 +74,7 @@ def make_env(cfg):
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
if cfg.get('action', 'discrete'):
if cfg.get('action', 'continuous') == 'discrete':
env = DiscreteWrapper(env)
try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}