modified weight initialization

This commit is contained in:
NM512
2024-01-05 10:46:54 +09:00
parent 4fe9b29ebe
commit a9e85e8b7c
3 changed files with 61 additions and 40 deletions

View File

@@ -179,7 +179,7 @@ class WorldModel(nn.Module):
# this function is called during both rollout and training
def preprocess(self, obs):
obs = obs.copy()
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
obs["image"] = torch.Tensor(obs["image"]) / 255.0
if "discount" in obs:
obs["discount"] *= self._config.discount
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
@@ -209,8 +209,8 @@ class WorldModel(nn.Module):
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed image is given until 5 steps
model = torch.cat([recon[:, :5], openl], 1)
truth = data["image"][:6] + 0.5
model = model + 0.5
truth = data["image"][:6]
model = model
error = (model - truth + 1.0) / 2.0
return torch.cat([truth, model, error], 2)