unified the place to initialize the latents

This commit is contained in:
NM512
2024-01-05 10:09:13 +09:00
parent 49d12baa48
commit e0f2017e28
3 changed files with 12 additions and 21 deletions

View File

@@ -202,7 +202,7 @@ class WorldModel(nn.Module):
]
reward_post = self.heads["reward"](self.dynamics.get_feat(states)).mode()[:6]
init = {k: v[:, -1] for k, v in states.items()}
prior = self.dynamics.imagine(data["action"][:6, 5:], init)
prior = self.dynamics.imagine_with_action(data["action"][:6, 5:], init)
openl = self.heads["decoder"](self.dynamics.get_feat(prior))["image"].mode()
reward_prior = self.heads["reward"](self.dynamics.get_feat(prior)).mode()
# observed image is given until 5 steps