fix
This commit is contained in:
@@ -78,7 +78,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||
|
||||
# Check action space compatibility
|
||||
if cfg.get('action', 'continuous') == 'discrete':
|
||||
assert cfg.action in ['continuous', 'discrete'], \
|
||||
f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]'
|
||||
if cfg.action == 'discrete':
|
||||
assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.'
|
||||
|
||||
# Check torch.compile compatibility
|
||||
|
||||
@@ -74,7 +74,7 @@ def make_env(cfg):
|
||||
env = TensorWrapper(env)
|
||||
if cfg.get('obs', 'state') == 'rgb':
|
||||
env = PixelWrapper(cfg, env)
|
||||
if cfg.get('action', 'discrete'):
|
||||
if cfg.get('action', 'continuous') == 'discrete':
|
||||
env = DiscreteWrapper(env)
|
||||
try: # Dict
|
||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||
|
||||
Reference in New Issue
Block a user