changed treatment of obs shape in minecraft
This commit is contained in:
13
models.py
13
models.py
@@ -174,20 +174,19 @@ class WorldModel(nn.Module):
|
||||
post = {k: v.detach() for k, v in post.items()}
|
||||
return post, context, metrics
|
||||
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["reward"] = torch.Tensor(obs["reward"]).unsqueeze(-1)
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
obs["discount"] = torch.Tensor(obs["discount"]).unsqueeze(-1)
|
||||
if "is_terminal" in obs:
|
||||
# this label is necessary to train cont_head
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
else:
|
||||
raise ValueError('"is_terminal" was not found in observation.')
|
||||
# 'is_first' is necesarry to initialize hidden state at training
|
||||
assert "is_first" in obs
|
||||
# 'is_terminal' is necesarry to train cont_head
|
||||
assert "is_terminal" in obs
|
||||
obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
|
||||
obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
|
||||
return obs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user