diff --git a/envs/minecraft_base.py b/envs/minecraft_base.py index e3ca82d..d7f0aba 100644 --- a/envs/minecraft_base.py +++ b/envs/minecraft_base.py @@ -163,10 +163,10 @@ class MinecraftBase(gym.Env): "inventory": inventory, "inventory_max": self._max_inventory.copy(), "equipped": equipped, - "health": np.float32(obs["life_stats/life"] / 20), - "hunger": np.float32(obs["life_stats/food"] / 20), - "breath": np.float32(obs["life_stats/air"] / 300), - "reward": 0.0, + "health": np.float32([obs["life_stats/life"]]) / 20, + "hunger": np.float32([obs["life_stats/food"]]) / 20, + "breath": np.float32([obs["life_stats/air"]]) / 300, + "reward": [0.0], "is_first": obs["is_first"], "is_last": obs["is_last"], "is_terminal": obs["is_terminal"], diff --git a/models.py b/models.py index 2c00597..b76d842 100644 --- a/models.py +++ b/models.py @@ -174,20 +174,19 @@ class WorldModel(nn.Module): post = {k: v.detach() for k, v in post.items()} return post, context, metrics + # this function is called during both rollout and training def preprocess(self, obs): obs = obs.copy() obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5 - # (batch_size, batch_length) -> (batch_size, batch_length, 1) - obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1) if "discount" in obs: obs["discount"] *= self._config.discount # (batch_size, batch_length) -> (batch_size, batch_length, 1) obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1) - if "is_terminal" in obs: - # this label is necessary to train cont_head - obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1) - else: - raise ValueError('"is_terminal" was not found in observation.') + # 'is_first' is necessary to initialize hidden state at training + assert "is_first" in obs + # 'is_terminal' is necessary to train cont_head + assert "is_terminal" in obs + obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1) obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()} return obs