modified envs

This commit is contained in:
NM512
2023-08-05 21:10:16 +09:00
parent a6ad132198
commit eb14e2488b
4 changed files with 7 additions and 18 deletions

View File

@@ -19,7 +19,6 @@ class Crafter:
"image": gym.spaces.Box(
0, 255, self._env.observation_space.shape, dtype=np.uint8
),
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
"is_first": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_last": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
"is_terminal": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.uint8),
@@ -50,7 +49,6 @@ class Crafter:
}
obs = {
"image": image,
"reward": reward,
"is_first": False,
"is_last": done,
"is_terminal": info["discount"] == 0,

View File

@@ -35,7 +35,6 @@ class MemoryMaze:
return gym.spaces.Dict(
{
**spaces,
"reward": gym.spaces.Box(-np.inf, np.inf, (), dtype=np.float32),
"is_first": gym.spaces.Box(0, 1, (), dtype=bool),
"is_last": gym.spaces.Box(0, 1, (), dtype=bool),
"is_terminal": gym.spaces.Box(0, 1, (), dtype=bool),
@@ -52,7 +51,6 @@ class MemoryMaze:
obs, reward, done, info = self._env.step(action)
if not self._obs_is_dict:
obs = {self._obs_key: obs}
obs["reward"] = reward
obs["is_first"] = False
obs["is_last"] = done
obs["is_terminal"] = info.get("is_terminal", False)
@@ -62,7 +60,6 @@ class MemoryMaze:
obs = self._env.reset()
if not self._obs_is_dict:
obs = {self._obs_key: obs}
obs["reward"] = 0.0
obs["is_first"] = True
obs["is_last"] = False
obs["is_terminal"] = False

View File

@@ -20,11 +20,11 @@ class MinecraftWood:
HealthReward(),
]
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
super().__init__(env)
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
obs["reward"] = reward
return obs, reward, done, info
@@ -34,6 +34,7 @@ class MinecraftClimb:
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
self._previous = None
self._health_reward = HealthReward()
super().__init__(env)
def step(self, action):
obs, reward, done, info = self.env.step(action)
@@ -43,7 +44,6 @@ class MinecraftClimb:
self._previous = height
reward = height - self._previous
reward += self._health_reward(obs)
obs["reward"] = reward
self._previous = height
return obs, reward, done, info
@@ -87,7 +87,6 @@ class MinecraftDiamond(gym.Wrapper):
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
obs["reward"] = reward
return obs, reward, done, info
def reset(self):
@@ -131,7 +130,7 @@ class HealthReward:
return 0
reward = self.scale * (health - self.previous)
self.previous = health
return np.float32(reward)
return sum(reward)
BASIC_ACTIONS = {

View File

@@ -18,7 +18,7 @@ class MinecraftBase(gym.Env):
sticky_attack=30,
sticky_jump=10,
pitch_limit=(-60, 60),
logs=True,
logs=False,
):
if logs:
logging.basicConfig(level=logging.DEBUG)
@@ -41,7 +41,6 @@ class MinecraftBase(gym.Env):
if k.startswith("inventory/")
if k != "inventory/log2"
]
self._step = 0
self._max_inventory = None
self._equip_enum = self._env.observation_space["equipped_items"]["mainhand"][
"type"
@@ -75,7 +74,6 @@ class MinecraftBase(gym.Env):
"equipped": gym.spaces.Box(
-np.inf, np.inf, (len(self._equip_enum),), dtype=np.float32
),
"reward": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
"health": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
"hunger": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
"breath": gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32),
@@ -110,12 +108,11 @@ class MinecraftBase(gym.Env):
if "error" in info:
done = True
break
obs["is_first"] = False
obs["is_last"] = bool(done)
obs["is_terminal"] = bool(info.get("is_terminal", done))
obs["is_first"] = False
obs["is_last"] = bool(done)
obs["is_terminal"] = bool(info.get("is_terminal", done))
obs = self._obs(obs)
self._step += 1
assert "pov" not in obs, list(obs.keys())
return obs, reward, done, info
@@ -135,7 +132,6 @@ class MinecraftBase(gym.Env):
obs["is_terminal"] = False
obs = self._obs(obs)
self._step = 0
self._sticky_attack_counter = 0
self._sticky_jump_counter = 0
self._pitch = 0
@@ -166,7 +162,6 @@ class MinecraftBase(gym.Env):
"health": np.float32([obs["life_stats/life"]]) / 20,
"hunger": np.float32([obs["life_stats/food"]]) / 20,
"breath": np.float32([obs["life_stats/air"]]) / 300,
"reward": [0.0],
"is_first": obs["is_first"],
"is_last": obs["is_last"],
"is_terminal": obs["is_terminal"],