modified envs

This commit is contained in:
NM512
2023-08-05 21:10:16 +09:00
parent a6ad132198
commit eb14e2488b
4 changed files with 7 additions and 18 deletions

View File

@@ -20,11 +20,11 @@ class MinecraftWood:
HealthReward(),
]
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
super().__init__(env)
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
obs["reward"] = reward
return obs, reward, done, info
@@ -34,6 +34,7 @@ class MinecraftClimb:
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
self._previous = None
self._health_reward = HealthReward()
super().__init__(env)
def step(self, action):
obs, reward, done, info = self.env.step(action)
@@ -43,7 +44,6 @@ class MinecraftClimb:
self._previous = height
reward = height - self._previous
reward += self._health_reward(obs)
obs["reward"] = reward
self._previous = height
return obs, reward, done, info
@@ -87,7 +87,6 @@ class MinecraftDiamond(gym.Wrapper):
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
obs["reward"] = reward
return obs, reward, done, info
def reset(self):
@@ -131,7 +130,7 @@ class HealthReward:
return 0
reward = self.scale * (health - self.previous)
self.previous = health
return np.float32(reward)
return sum(reward)
BASIC_ACTIONS = {