unified the place to initialize the latents

This commit is contained in:
NM512
2024-01-05 10:09:13 +09:00
parent 49d12baa48
commit e0f2017e28
3 changed files with 12 additions and 21 deletions

View File

@@ -59,15 +59,6 @@ class Dreamer(nn.Module):
def __call__(self, obs, reset, state=None, training=True):
step = self._step
if self._should_reset(step):
state = None
if state is not None and reset.any():
mask = 1 - reset
for key in state[0].keys():
for i in range(state[0][key].shape[0]):
state[0][key][i] *= mask[i]
for i in range(len(state[1])):
state[1][i] *= mask[i]
if training:
steps = (
self._config.pretrain
@@ -96,11 +87,7 @@ class Dreamer(nn.Module):
def _policy(self, obs, state, training):
if state is None:
batch_size = len(obs["image"])
latent = self._wm.dynamics.initial(len(obs["image"]))
action = torch.zeros((batch_size, self._config.num_actions)).to(
self._config.device
)
latent = action = None
else:
latent, action = state
obs = self._wm.preprocess(obs)