modified weight initialization
This commit is contained in:
@@ -179,7 +179,7 @@ class WorldModel(nn.Module):
|
||||
# this function is called during both rollout and training
|
||||
def preprocess(self, obs):
|
||||
obs = obs.copy()
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
|
||||
obs["image"] = torch.Tensor(obs["image"]) / 255.0
|
||||
if "discount" in obs:
|
||||
obs["discount"] *= self._config.discount
|
||||
# (batch_size, batch_length) -> (batch_size, batch_length, 1)
|
||||
@@ -209,8 +209,8 @@ class WorldModel(nn.Module):
|
||||
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
|
||||
# observed image is given until 5 steps
|
||||
model = torch.cat([recon[:, :5], openl], 1)
|
||||
truth = data["image"][:6] + 0.5
|
||||
model = model + 0.5
|
||||
truth = data["image"][:6]
|
||||
model = model
|
||||
error = (model - truth + 1.0) / 2.0
|
||||
|
||||
return torch.cat([truth, model, error], 2)
|
||||
|
||||
Reference in New Issue
Block a user