unified the place to initialize the latents
This commit is contained in:
15
dreamer.py
15
dreamer.py
@@ -59,15 +59,6 @@ class Dreamer(nn.Module):
|
||||
|
||||
def __call__(self, obs, reset, state=None, training=True):
|
||||
step = self._step
|
||||
if self._should_reset(step):
|
||||
state = None
|
||||
if state is not None and reset.any():
|
||||
mask = 1 - reset
|
||||
for key in state[0].keys():
|
||||
for i in range(state[0][key].shape[0]):
|
||||
state[0][key][i] *= mask[i]
|
||||
for i in range(len(state[1])):
|
||||
state[1][i] *= mask[i]
|
||||
if training:
|
||||
steps = (
|
||||
self._config.pretrain
|
||||
@@ -96,11 +87,7 @@ class Dreamer(nn.Module):
|
||||
|
||||
def _policy(self, obs, state, training):
|
||||
if state is None:
|
||||
batch_size = len(obs["image"])
|
||||
latent = self._wm.dynamics.initial(len(obs["image"]))
|
||||
action = torch.zeros((batch_size, self._config.num_actions)).to(
|
||||
self._config.device
|
||||
)
|
||||
latent = action = None
|
||||
else:
|
||||
latent, action = state
|
||||
obs = self._wm.preprocess(obs)
|
||||
|
||||
Reference in New Issue
Block a user