learnable initial state options for RSSM

This commit is contained in:
NM512
2023-04-29 07:54:03 +09:00
parent 1328ff1088
commit 0eb66997fb
4 changed files with 60 additions and 12 deletions

View File

@@ -110,9 +110,10 @@ class Dreamer(nn.Module):
)
else:
latent, action = state
embed = self._wm.encoder(self._wm.preprocess(obs))
obs = self._wm.preprocess(obs)
embed = self._wm.encoder(obs)
latent, _ = self._wm.dynamics.obs_step(
latent, action, embed, self._config.collect_dyn_sample
latent, action, embed, obs["is_first"], self._config.collect_dyn_sample
)
if self._config.eval_state_mean:
latent["stoch"] = latent["mean"]