added log for inventory items in minecraft

This commit is contained in:
NM512
2023-08-16 15:52:33 +09:00
parent 99dc4e4ed1
commit 68096d1f62
3 changed files with 42 additions and 17 deletions

View File

@@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
"place_furnace": dict(place="furnace"),
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
}
self.rewards = [
CollectReward("log", once=1),
CollectReward("planks", once=1),
CollectReward("stick", once=1),
CollectReward("crafting_table", once=1),
CollectReward("wooden_pickaxe", once=1),
CollectReward("cobblestone", once=1),
CollectReward("stone_pickaxe", once=1),
CollectReward("iron_ore", once=1),
CollectReward("furnace", once=1),
CollectReward("iron_ingot", once=1),
CollectReward("iron_pickaxe", once=1),
CollectReward("diamond", once=1),
HealthReward(),
self.items = [
"log",
"planks",
"stick",
"crafting_table",
"wooden_pickaxe",
"cobblestone",
"stone_pickaxe",
"iron_ore",
"furnace",
"iron_ingot",
"iron_pickaxe",
"diamond",
]
self.rewards = [CollectReward(item, once=1) for item in self.items] + [
HealthReward()
]
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
super().__init__(env)
@@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
# restrict log for memory save
obs = {
k: v
for k, v in obs.items()
if "log" not in k or k.split("/")[-1] in self.items
}
return obs, reward, done, info
def reset(self):
obs = self.env.reset()
# called for reset of reward calculations
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
# restrict log for memory save
obs = {
k: v
for k, v in obs.items()
if "log" not in k or k.split("/")[-1] in self.items
}
return obs