diff --git a/configs.yaml b/configs.yaml
index 6311333..33b6efd 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -172,9 +172,21 @@ atari100k:
   imag_gradient: 'reinforce'
   time_limit: 108000
 
+
 debug:
   debug: True
   pretrain: 1
   prefill: 1
   batch_size: 10
   batch_length: 20
+
+MemoryMaze:
+  actor_dist: 'onehot'
+  imag_gradient: 'reinforce'
+  task: '9x9'
+  steps: 1e6
+  action_repeat: 2
+
+
+
+
diff --git a/dreamer.py b/dreamer.py
index c32d66f..fec590f 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -211,6 +211,17 @@ def make_env(config, logger, mode, train_eps, eval_eps):
             task, mode if "train" in mode else "test", config.action_repeat
         )
         env = wrappers.OneHotAction(env)
+    elif suite == "MemoryMaze":
+        import gym
+        if task == '9x9':
+            env = gym.make('memory_maze:MemoryMaze-9x9-v0')
+        elif task == '15x15':
+            env = gym.make('memory_maze:MemoryMaze-15x15-v0')
+        else:
+            raise NotImplementedError(task)
+        from envs.memorymaze import MemoryMaze
+        env = MemoryMaze(env)
+        env = wrappers.OneHotAction(env)
     elif suite == "crafter":
         import envs.crafter as crafter
 
diff --git a/envs/memorymaze.py b/envs/memorymaze.py
new file mode 100644
index 0000000..a194368
--- /dev/null
+++ b/envs/memorymaze.py
@@ -0,0 +1,98 @@
+import atexit
+import os
+import sys
+
+import cloudpickle
+import gym
+import numpy as np
+
+# Adapted from the TensorFlow DreamerV2 code.
+
+
+class MemoryMaze:
+    """Adapter around a Memory Maze gym environment.
+
+    Observations are returned as dicts that carry the ``is_first``,
+    ``is_last`` and ``is_terminal`` episode flags with the wrapped env's data.
+    """
+
+    def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)):
+        self._env = env
+        # A space is treated as dict-like when it has a ``spaces`` attribute.
+        self._obs_is_dict = hasattr(self._env.observation_space, 'spaces')
+        self._act_is_dict = hasattr(self._env.action_space, 'spaces')
+        self._obs_key = obs_key
+        self._act_key = act_key
+        self._size = size
+        self._gray = False
+
+    def __getattr__(self, name):
+        """Forward unknown non-dunder attributes to the wrapped environment."""
+        if name.startswith('__'):
+            raise AttributeError(name)
+        # A missing attribute must surface as AttributeError so that
+        # hasattr() and getattr(obj, name, default) keep working.
+        return getattr(self._env, name)
+
+    @property
+    def obs_space(self):
+        """Dict of observation spaces, including reward and episode flags."""
+        if self._obs_is_dict:
+            spaces = self._env.observation_space.spaces.copy()
+        else:
+            spaces = {self._obs_key: self._env.observation_space}
+        # The np.bool alias was removed in NumPy 1.24; the builtin bool is
+        # the equivalent dtype on every NumPy version.
+        return {
+            **spaces,
+            'reward': gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
+            'is_first': gym.spaces.Box(0, 1, (), dtype=bool),
+            'is_last': gym.spaces.Box(0, 1, (), dtype=bool),
+            'is_terminal': gym.spaces.Box(0, 1, (), dtype=bool),
+        }
+
+    @property
+    def act_space(self):
+        """Dict of action spaces, keyed by ``act_key`` for a plain space."""
+        if self._act_is_dict:
+            return self._env.action_space.spaces.copy()
+        return {self._act_key: self._env.action_space}
+
+    @property
+    def observation_space(self):
+        """Gym Dict space holding a single uint8 ``image`` entry."""
+        img_shape = self._size + ((1,) if self._gray else (3,))
+        return gym.spaces.Dict(
+            {
+                "image": gym.spaces.Box(0, 255, img_shape, np.uint8),
+            }
+        )
+
+    @property
+    def action_space(self):
+        """Wrapped action space, tagged in place with ``discrete = True``."""
+        space = self._env.action_space
+        space.discrete = True
+        return space
+
+    def step(self, action):
+        """Step the wrapped env and add episode flags to the observation."""
+        obs, reward, done, info = self._env.step(action)
+        if not self._obs_is_dict:
+            obs = {self._obs_key: obs}
+        # The reward is returned separately; it is not written into obs.
+        obs['is_first'] = False
+        obs['is_last'] = done
+        obs['is_terminal'] = info.get('is_terminal', False)
+        return obs, reward, done, info
+
+    def reset(self):
+        """Reset the wrapped env and return the first observation dict."""
+        obs = self._env.reset()
+        if not self._obs_is_dict:
+            obs = {self._obs_key: obs}
+        obs['reward'] = 0.0
+        obs['is_first'] = True
+        obs['is_last'] = False
+        obs['is_terminal'] = False
+        return obs