diff --git a/README.md b/README.md
index 7b660bc..4e37bf5 100644
--- a/README.md
+++ b/README.md
@@ -36,8 +36,7 @@ So far, the following benchmarks can be used for testing.

#### Crafter
-
-
+
## Acknowledgments
This code is heavily inspired by the following works:
diff --git a/envs/minecraft.py b/envs/minecraft.py
index 59f13e7..338d31f 100644
--- a/envs/minecraft.py
+++ b/envs/minecraft.py
@@ -66,20 +66,22 @@ class MinecraftDiamond(gym.Wrapper):
"place_furnace": dict(place="furnace"),
"smelt_iron_ingot": dict(nearbySmelt="iron_ingot"),
}
- self.rewards = [
- CollectReward("log", once=1),
- CollectReward("planks", once=1),
- CollectReward("stick", once=1),
- CollectReward("crafting_table", once=1),
- CollectReward("wooden_pickaxe", once=1),
- CollectReward("cobblestone", once=1),
- CollectReward("stone_pickaxe", once=1),
- CollectReward("iron_ore", once=1),
- CollectReward("furnace", once=1),
- CollectReward("iron_ingot", once=1),
- CollectReward("iron_pickaxe", once=1),
- CollectReward("diamond", once=1),
- HealthReward(),
+ self.items = [
+ "log",
+ "planks",
+ "stick",
+ "crafting_table",
+ "wooden_pickaxe",
+ "cobblestone",
+ "stone_pickaxe",
+ "iron_ore",
+ "furnace",
+ "iron_ingot",
+ "iron_pickaxe",
+ "diamond",
+ ]
+ self.rewards = [CollectReward(item, once=1) for item in self.items] + [
+ HealthReward()
]
env = minecraft_base.MinecraftBase(actions, *args, **kwargs)
super().__init__(env)
@@ -87,12 +89,24 @@ class MinecraftDiamond(gym.Wrapper):
def step(self, action):
obs, reward, done, info = self.env.step(action)
reward = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+ # keep only the log entries of tracked items to save memory
+ obs = {
+ k: v
+ for k, v in obs.items()
+ if "log" not in k or k.split("/")[-1] in self.items
+ }
return obs, reward, done, info
def reset(self):
obs = self.env.reset()
# called for reset of reward calculations
_ = sum([fn(obs, self.env.inventory) for fn in self.rewards])
+ # keep only the log entries of tracked items to save memory
+ obs = {
+ k: v
+ for k, v in obs.items()
+ if "log" not in k or k.split("/")[-1] in self.items
+ }
return obs
diff --git a/tools.py b/tools.py
index 1cfed7d..d5da141 100644
--- a/tools.py
+++ b/tools.py
@@ -84,7 +84,10 @@ class Logger:
with (self._logdir / "metrics.jsonl").open("a") as f:
f.write(json.dumps({"step": step, **dict(scalars)}) + "\n")
for name, value in scalars:
- self._writer.add_scalar("scalars/" + name, value, step)
+ if "/" not in name:
+ self._writer.add_scalar("scalars/" + name, value, step)
+ else:
+ self._writer.add_scalar(name, value, step)
for name, value in self._images.items():
self._writer.add_image(name, value, step)
for name, value in self._videos.items():
@@ -203,6 +206,15 @@ def simulate(
length = len(cache[envs[i].id]["reward"]) - 1
score = float(np.array(cache[envs[i].id]["reward"]).sum())
video = cache[envs[i].id]["image"]
+ # record the log values provided by the environments
+ for key in list(cache[envs[i].id].keys()):
+ if "log_" in key:
+ logger.scalar(
+ key, float(np.array(cache[envs[i].id][key]).sum())
+ )
+ # log entries are not used afterwards, so drop them from the cache
+ cache[envs[i].id].pop(key)
+
if not is_eval:
step_in_dataset = erase_over_episodes(cache, limit)
logger.scalar(f"dataset_size", step_in_dataset)