From 8e005afde55da3c2cdec456218c055501392d067 Mon Sep 17 00:00:00 2001 From: zdx <179363811@qq.com> Date: Sun, 18 Jun 2023 09:16:32 +0800 Subject: [PATCH] mem maze env ok 1.2 --- configs.yaml | 6 +++-- dreamer.py | 12 ++++----- envs/{memmazeEnv.py => memorymaze.py} | 6 ++--- envs/wrappers.py | 38 ++------------------------- tools.py | 2 +- 5 files changed, 15 insertions(+), 49 deletions(-) rename envs/{memmazeEnv.py => memorymaze.py} (95%) diff --git a/configs.yaml b/configs.yaml index 4750ee3..18e96a1 100644 --- a/configs.yaml +++ b/configs.yaml @@ -156,8 +156,10 @@ debug: batch_size: 10 batch_length: 20 -mazegym: - task: '9' +MemoryMaze: + actor_dist: 'onehot' + imag_gradient: 'reinforce' + task: '9x9' steps: 1e6 action_repeat: 2 diff --git a/dreamer.py b/dreamer.py index 3e90050..24750d1 100644 --- a/dreamer.py +++ b/dreamer.py @@ -210,17 +210,17 @@ def make_env(config, logger, mode, train_eps, eval_eps): task, mode if "train" in mode else "test", config.action_repeat ) env = wrappers.OneHotAction(env) - elif suite == "mazegym": + elif suite == "MemoryMaze": import gym - if task == '9': + if task == '9x9': env = gym.make('memory_maze:MemoryMaze-9x9-v0') - elif task == '15': + elif task == '15x15': env = gym.make('memory_maze:MemoryMaze-15x15-v0') else: raise NotImplementedError(suite) - from envs.memmazeEnv import MZGymWrapper - env = MZGymWrapper(env) - env = wrappers.OneHotAction2(env) + from envs.memorymaze import MemoryMaze + env = MemoryMaze(env) + env = wrappers.OneHotAction(env) else: raise NotImplementedError(suite) env = wrappers.TimeLimit(env, config.time_limit) diff --git a/envs/memmazeEnv.py b/envs/memorymaze.py similarity index 95% rename from envs/memmazeEnv.py rename to envs/memorymaze.py index 980f805..a194368 100644 --- a/envs/memmazeEnv.py +++ b/envs/memorymaze.py @@ -1,8 +1,6 @@ import atexit import os import sys -import threading -import traceback import cloudpickle import gym @@ -10,7 +8,7 @@ import numpy as np ###from tf dreamerv2 code -class MZGymWrapper: +class MemoryMaze: def __init__(self, env, obs_key='image', act_key='action', size=(64, 64)): self._env = env @@ -74,7 +72,7 @@ class MZGymWrapper: # obs['reward'] = float(reward) obs['is_first'] = False obs['is_last'] = done - obs['is_terminal'] = info.get('is_terminal', done) + obs['is_terminal'] = info.get('is_terminal', False) return obs, reward, done, info def reset(self): diff --git a/envs/wrappers.py b/envs/wrappers.py index 03ff649..1a4a58b 100644 --- a/envs/wrappers.py +++ b/envs/wrappers.py @@ -77,8 +77,8 @@ class CollectDataset: dtype = {16: np.int16, 32: np.int32, 64: np.int64}[self._precision] elif np.issubdtype(value.dtype, np.uint8): dtype = np.uint8 - elif np.issubdtype(value.dtype, np.bool_): - dtype = np.bool_ + elif np.issubdtype(value.dtype, np.bool): + dtype = np.bool else: raise NotImplementedError(value.dtype) return value.astype(dtype) @@ -168,40 +168,6 @@ class OneHotAction: reference[index] = 1.0 return reference -class OneHotAction2: - def __init__(self, env): - assert isinstance(env.action_space, gym.spaces.Discrete) - self._env = env - self._random = np.random.RandomState() - - def __getattr__(self, name): - return getattr(self._env, name) - - @property - def action_space(self): - shape = (self._env.action_space.n,) - space = gym.spaces.Box(low=0, high=1, shape=shape, dtype=np.float32) - space.sample = self._sample_action - space.discrete = True - return space - - def step(self, action): - index = np.argmax(action).astype(int) - reference = np.zeros_like(action) - reference[index] = 1 - # if not np.allclose(reference, action): - # raise ValueError(f"Invalid one-hot action:\n{action}") - return self._env.step(index) - - def reset(self): - return self._env.reset() - - def _sample_action(self): - actions = self._env.action_space.n - index = self._random.randint(0, actions) - reference = np.zeros(actions, dtype=np.float32) - reference[index] = 1.0 - return reference class RewardObs: def __init__(self, env): diff --git a/tools.py b/tools.py index 752b786..bc46903 100644 --- a/tools.py +++ b/tools.py @@ -127,7 +127,7 @@ def simulate(agent, envs, steps=0, episodes=0, state=None): # Initialize or unpack simulation state. if state is None: step, episode = 0, 0 - done = np.ones(len(envs), np.bool_) + done = np.ones(len(envs), np.bool) length = np.zeros(len(envs), np.int32) obs = [None] * len(envs) agent_state = None