modified weight initialization

This commit is contained in:
NM512
2024-01-05 10:46:54 +09:00
parent 4fe9b29ebe
commit a9e85e8b7c
3 changed files with 61 additions and 40 deletions

View File

@@ -68,9 +68,8 @@ class RSSM(nn.Module):
inp_layers.append(act())
if i == 0:
inp_dim = self._hidden
self._inp_layers = nn.Sequential(*inp_layers)
self._inp_layers.apply(tools.weight_init)
self._img_in_layers = nn.Sequential(*inp_layers)
self._img_in_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
@@ -106,15 +105,17 @@ class RSSM(nn.Module):
self._obs_out_layers.apply(tools.weight_init)
if self._discrete:
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._ims_stat_layer.apply(tools.weight_init)
self._imgs_stat_layer = nn.Linear(
self._hidden, self._stoch * self._discrete
)
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._obs_stat_layer.apply(tools.weight_init)
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
else:
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._ims_stat_layer.apply(tools.weight_init)
self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
if self._initial == "learned":
self.W = torch.nn.Parameter(
@@ -260,7 +261,7 @@ class RSSM(nn.Module):
else:
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
x = self._inp_layers(x)
x = self._img_in_layers(x)
for _ in range(self._rec_depth): # rec depth is not correctly implemented
deter = prev_state["deter"]
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
@@ -286,7 +287,7 @@ class RSSM(nn.Module):
def _suff_stats_layer(self, name, x):
if self._discrete:
if name == "ims":
x = self._ims_stat_layer(x)
x = self._imgs_stat_layer(x)
elif name == "obs":
x = self._obs_stat_layer(x)
else:
@@ -295,7 +296,7 @@ class RSSM(nn.Module):
return {"logit": logit}
else:
if name == "ims":
x = self._ims_stat_layer(x)
x = self._imgs_stat_layer(x)
elif name == "obs":
x = self._obs_stat_layer(x)
else:
@@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
act,
norm,
symlog_inputs=symlog_inputs,
name="Encoder",
)
self.outdim += mlp_units
@@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
cnn_sigmoid,
image_dist,
vector_dist,
outscale,
):
super(MultiDecoder, self).__init__()
excluded = ("is_first", "is_last", "is_terminal")
@@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
norm,
kernel_size,
minres,
outscale=outscale,
cnn_sigmoid=cnn_sigmoid,
)
if self.mlp_shapes:
@@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
act,
norm,
vector_dist,
outscale=outscale,
name="Decoder",
)
self._image_dist = image_dist
@@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
input_shape,
depth=32,
act="SiLU",
norm="LayerNorm",
norm=True,
kernel_size=4,
minres=4,
):
super(ConvEncoder, self).__init__()
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
h, w, input_ch = input_shape
stages = int(np.log2(h) - np.log2(minres))
in_dim = input_ch
out_dim = depth
layers = []
for i in range(int(np.log2(h) - np.log2(minres))):
if i == 0:
in_dim = input_ch
else:
in_dim = 2 ** (i - 1) * depth
out_dim = 2**i * depth
for i in range(stages):
layers.append(
Conv2dSame(
in_channels=in_dim,
@@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
bias=False,
)
)
layers.append(ChLayerNorm(out_dim))
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(act())
in_dim = out_dim
out_dim *= 2
h, w = h // 2, w // 2
self.outdim = out_dim * h * w
self.outdim = out_dim // 2 * h * w
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
def forward(self, obs):
obs -= 0.5
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
@@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
shape=(3, 64, 64),
depth=32,
act=nn.ELU,
norm=nn.LayerNorm,
norm=True,
kernel_size=4,
minres=4,
outscale=1.0,
@@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
):
super(ConvDecoder, self).__init__()
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._shape = shape
self._cnn_sigmoid = cnn_sigmoid
layer_num = int(np.log2(shape[1]) - np.log2(minres))
self._minres = minres
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
out_ch = minres**2 * depth * 2 ** (layer_num - 1)
self._embed_size = out_ch
self._linear_layer = nn.Linear(feat_size, self._embed_size)
self._linear_layer.apply(tools.weight_init)
in_dim = self._embed_size // (minres**2)
self._linear_layer = nn.Linear(feat_size, out_ch)
self._linear_layer.apply(tools.uniform_weight_init(outscale))
in_dim = out_ch // (minres**2)
out_dim = in_dim // 2
layers = []
h, w = minres, minres
for i in range(layer_num):
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
bias = False
initializer = tools.weight_init
if i == layer_num - 1:
out_dim = self._shape[0]
act = False
bias = True
norm = False
initializer = tools.uniform_weight_init(outscale)
if i != 0:
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
@@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
layers.append(ChLayerNorm(out_dim))
if act:
layers.append(act())
[m.apply(initializer) for m in layers[-3:]]
in_dim = out_dim
out_dim //= 2
h, w = h * 2, w * 2
[m.apply(tools.weight_init) for m in layers[:-1]]
layers[-1].apply(tools.uniform_weight_init(outscale))
self.layers = nn.Sequential(*layers)
def calc_same_pad(self, k, s, d):
@@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
# (batch, time, -1) -> (batch * time, ch, h, w)
x = x.permute(0, 3, 1, 2)
x = self.layers(x)
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
# (batch, time, -1) -> (batch, time, ch, h, w)
mean = x.reshape(features.shape[:-1] + self._shape)
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
mean = mean.permute(0, 1, 3, 4, 2)
if self._cnn_sigmoid:
mean = F.sigmoid(mean) - 0.5
mean = F.sigmoid(mean)
else:
mean += 0.5
return mean