learnable initial state options for RSSM
This commit is contained in:
52
networks.py
52
networks.py
@@ -28,6 +28,7 @@ class RSSM(nn.Module):
|
||||
min_std=0.1,
|
||||
cell="gru",
|
||||
unimix_ratio=0.01,
|
||||
initial="learned",
|
||||
num_actions=None,
|
||||
embed=None,
|
||||
device=None,
|
||||
@@ -48,6 +49,7 @@ class RSSM(nn.Module):
|
||||
self._std_act = std_act
|
||||
self._temp_post = temp_post
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._initial = initial
|
||||
self._embed = embed
|
||||
self._device = device
|
||||
|
||||
@@ -112,6 +114,12 @@ class RSSM(nn.Module):
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
|
||||
if self._initial == "learned":
|
||||
self.W = torch.nn.Parameter(
|
||||
torch.zeros((1, self._deter), device=torch.device(self._device)),
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
def initial(self, batch_size):
|
||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
||||
if self._discrete:
|
||||
@@ -131,19 +139,27 @@ class RSSM(nn.Module):
|
||||
stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
|
||||
deter=deter,
|
||||
)
|
||||
return state
|
||||
if self._initial == "zeros":
|
||||
return state
|
||||
elif self._initial == "learned":
|
||||
state["deter"] = torch.tanh(self.W).repeat(batch_size, 1)
|
||||
state["stoch"] = self.get_stoch(state["deter"])
|
||||
return state
|
||||
else:
|
||||
raise NotImplementedError(self._initial)
|
||||
|
||||
def observe(self, embed, action, state=None):
|
||||
def observe(self, embed, action, is_first, state=None):
|
||||
swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
|
||||
if state is None:
|
||||
state = self.initial(action.shape[0])
|
||||
# (batch, time, ch) -> (time, batch, ch)
|
||||
embed, action = swap(embed), swap(action)
|
||||
embed, action, is_first = swap(embed), swap(action), swap(is_first)
|
||||
# prev_state[0] means selecting posterior of return(posterior, prior) from obs_step
|
||||
post, prior = tools.static_scan(
|
||||
lambda prev_state, prev_act, embed: self.obs_step(
|
||||
prev_state[0], prev_act, embed
|
||||
lambda prev_state, prev_act, embed, is_first: self.obs_step(
|
||||
prev_state[0], prev_act, embed, is_first
|
||||
),
|
||||
(action, embed),
|
||||
(action, embed, is_first),
|
||||
(state, state),
|
||||
)
|
||||
|
||||
@@ -184,10 +200,22 @@ class RSSM(nn.Module):
|
||||
)
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, sample=True):
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
|
||||
if torch.sum(is_first) > 0:
|
||||
is_first = is_first[:, None]
|
||||
prev_action *= 1.0 - is_first
|
||||
init_state = self.initial(len(is_first))
|
||||
for key, val in prev_state.items():
|
||||
is_first_r = torch.reshape(
|
||||
is_first,
|
||||
is_first.shape + (1,) * (len(val.shape) - len(is_first.shape)),
|
||||
)
|
||||
val = val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
@@ -242,6 +270,12 @@ class RSSM(nn.Module):
|
||||
prior = {"stoch": stoch, "deter": deter, **stats}
|
||||
return prior
|
||||
|
||||
def get_stoch(self, deter):
|
||||
x = self._img_out_layers(deter)
|
||||
stats = self._suff_stats_layer("ims", x)
|
||||
dist = self.get_dist(stats)
|
||||
return dist.mode()
|
||||
|
||||
def _suff_stats_layer(self, name, x):
|
||||
if self._discrete:
|
||||
if name == "ims":
|
||||
@@ -435,6 +469,7 @@ class DenseHead(nn.Module):
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
outscale=1.0,
|
||||
device="cuda",
|
||||
):
|
||||
super(DenseHead, self).__init__()
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
@@ -446,6 +481,7 @@ class DenseHead(nn.Module):
|
||||
self._norm = norm
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._device = device
|
||||
|
||||
layers = []
|
||||
for index in range(self._layers):
|
||||
@@ -491,7 +527,7 @@ class DenseHead(nn.Module):
|
||||
)
|
||||
)
|
||||
if self._dist == "twohot_symlog":
|
||||
return tools.TwoHotDistSymlog(logits=mean)
|
||||
return tools.TwoHotDistSymlog(logits=mean, device=self._device)
|
||||
raise NotImplementedError(self._dist)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user