unified the place to initialize the latents

This commit is contained in:
NM512
2024-01-05 10:09:13 +09:00
parent 49d12baa48
commit e0f2017e28
3 changed files with 12 additions and 21 deletions

View File

@@ -51,6 +51,7 @@ class RSSM(nn.Module):
self._temp_post = temp_post
self._unimix_ratio = unimix_ratio
self._initial = initial
self._num_actions = num_actions
self._embed = embed
self._device = device
@@ -151,8 +152,6 @@ class RSSM(nn.Module):
def observe(self, embed, action, is_first, state=None):
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(action.shape[0])
# (batch, time, ch) -> (time, batch, ch)
embed, action, is_first = swap(embed), swap(action), swap(is_first)
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
@@ -169,10 +168,8 @@ class RSSM(nn.Module):
prior = {k: swap(v) for k, v in prior.items()}
return post, prior
def imagine(self, action, state=None):
def imagine_with_action(self, action, state):
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(action.shape[0])
assert isinstance(state, dict), state
action = action
action = swap(action)
@@ -206,7 +203,14 @@ class RSSM(nn.Module):
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
if torch.sum(is_first) > 0:
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
prev_state = self.initial(len(is_first))
prev_action = torch.zeros((len(is_first), self._num_actions)).to(
self._device
)
# overwrite the prev_state only where is_first=True
elif torch.sum(is_first) > 0:
is_first = is_first[:, None]
prev_action *= 1.0 - is_first
init_state = self.initial(len(is_first))