changed treatment of obs shape in minecraft

This commit is contained in:
NM512
2023-08-03 08:12:44 +09:00
parent d94a719421
commit 3f6659d365
2 changed files with 10 additions and 11 deletions

View File

@@ -174,20 +174,19 @@ class WorldModel(nn.Module):
post = {k: v.detach() for k, v in post.items()}
return post, context, metrics
# this function is called during both rollout and training
def preprocess(self, obs):
obs = obs.copy()
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
if "is_terminal" in obs:
# this label is necessary to train cont_head
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
else:
raise ValueError('"is_terminal" was not found in observation.')
# 'is_first' is necesarry to initialize hidden state at training
assert "is_first" in obs
# 'is_terminal' is necesarry to train cont_head
assert "is_terminal" in obs
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
return obs