learnable initial state options for RSSM

This commit is contained in:
NM512
2023-04-29 07:54:03 +09:00
parent 1328ff1088
commit 0eb66997fb
4 changed files with 60 additions and 12 deletions

View File

@@ -28,6 +28,7 @@ class RSSM(nn.Module):
min_std=0.1,
cell="gru",
unimix_ratio=0.01,
initial="learned",
num_actions=None,
embed=None,
device=None,
@@ -48,6 +49,7 @@ class RSSM(nn.Module):
self._std_act = std_act
self._temp_post = temp_post
self._unimix_ratio = unimix_ratio
self._initial = initial
self._embed = embed
self._device = device
@@ -112,6 +114,12 @@ class RSSM(nn.Module):
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
if self._initial == "learned":
self.W = torch.nn.Parameter(
torch.zeros((1, self._deter), device=torch.device(self._device)),
requires_grad=True,
)
def initial(self, batch_size):
deter = torch.zeros(batch_size, self._deter).to(self._device)
if self._discrete:
@@ -131,19 +139,27 @@ class RSSM(nn.Module):
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
deter=deter,
)
return state
if self._initial == "zeros":
return state
elif self._initial == "learned":
state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
state["stoch"] = self.get_stoch(state["deter"])
return state
else:
raise NotImplementedError(self._initial)
def observe(self, embed, action, state=None):
def observe(self, embed, action, is_first, state=None):
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
if state is None:
state = self.initial(action.shape[0])
# (batch, time, ch) -> (time, batch, ch)
embed, action = swap(embed), swap(action)
embed, action, is_first = swap(embed), swap(action), swap(is_first)
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
post, prior = tools.static_scan(
lambda prev_state, prev_act, embed: self.obs_step(
prev_state[0], prev_act, embed
lambda prev_state, prev_act, embed, is_first: self.obs_step(
prev_state[0], prev_act, embed, is_first
),
(action, embed),
(action, embed, is_first),
(state, state),
)
@@ -184,10 +200,22 @@ class RSSM(nn.Module):
)
return dist
def obs_step(self, prev_state, prev_action, embed, sample=True):
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
if torch.sum(is_first) > 0:
is_first = is_first[:, None]
prev_action *= 1.0 - is_first
init_state = self.initial(len(is_first))
for key, val in prev_state.items():
is_first_r = torch.reshape(
is_first,
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
)
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@@ -242,6 +270,12 @@ class RSSM(nn.Module):
prior = {"stoch": stoch, "deter": deter, **stats}
return prior
def get_stoch(self, deter):
x = self._img_out_layers(deter)
stats = self._suff_stats_layer("ims", x)
dist = self.get_dist(stats)
return dist.mode()
def _suff_stats_layer(self, name, x):
if self._discrete:
if name == "ims":
@@ -435,6 +469,7 @@ class DenseHead(nn.Module):
dist="normal",
std=1.0,
outscale=1.0,
device="cuda",
):
super(DenseHead, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@@ -446,6 +481,7 @@ class DenseHead(nn.Module):
self._norm = norm
self._dist = dist
self._std = std
self._device = device
layers = []
for index in range(self._layers):
@@ -491,7 +527,7 @@ class DenseHead(nn.Module):
)
)
if self._dist == "twohot_symlog":
return tools.TwoHotDistSymlog(logits=mean)
return tools.TwoHotDistSymlog(logits=mean, device=self._device)
raise NotImplementedError(self._dist)