modified weight initialization
This commit is contained in:
81
networks.py
81
networks.py
@@ -68,9 +68,8 @@ class RSSM(nn.Module):
|
||||
inp_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._inp_layers = nn.Sequential(*inp_layers)
|
||||
self._inp_layers.apply(tools.weight_init)
|
||||
|
||||
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||
self._img_in_layers.apply(tools.weight_init)
|
||||
if cell == "gru":
|
||||
self._cell = GRUCell(self._hidden, self._deter)
|
||||
self._cell.apply(tools.weight_init)
|
||||
@@ -106,15 +105,17 @@ class RSSM(nn.Module):
|
||||
self._obs_out_layers.apply(tools.weight_init)
|
||||
|
||||
if self._discrete:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._imgs_stat_layer = nn.Linear(
|
||||
self._hidden, self._stoch * self._discrete
|
||||
)
|
||||
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
else:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._imgs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._imgs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer.apply(tools.uniform_weight_init(1.0))
|
||||
|
||||
if self._initial == "learned":
|
||||
self.W = torch.nn.Parameter(
|
||||
@@ -260,7 +261,7 @@ class RSSM(nn.Module):
|
||||
else:
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||
x = self._inp_layers(x)
|
||||
x = self._img_in_layers(x)
|
||||
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
||||
deter = prev_state["deter"]
|
||||
# (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
|
||||
@@ -286,7 +287,7 @@ class RSSM(nn.Module):
|
||||
def _suff_stats_layer(self, name, x):
|
||||
if self._discrete:
|
||||
if name == "ims":
|
||||
x = self._ims_stat_layer(x)
|
||||
x = self._imgs_stat_layer(x)
|
||||
elif name == "obs":
|
||||
x = self._obs_stat_layer(x)
|
||||
else:
|
||||
@@ -295,7 +296,7 @@ class RSSM(nn.Module):
|
||||
return {"logit": logit}
|
||||
else:
|
||||
if name == "ims":
|
||||
x = self._ims_stat_layer(x)
|
||||
x = self._imgs_stat_layer(x)
|
||||
elif name == "obs":
|
||||
x = self._obs_stat_layer(x)
|
||||
else:
|
||||
@@ -386,6 +387,7 @@ class MultiEncoder(nn.Module):
|
||||
act,
|
||||
norm,
|
||||
symlog_inputs=symlog_inputs,
|
||||
name="Encoder",
|
||||
)
|
||||
self.outdim += mlp_units
|
||||
|
||||
@@ -418,6 +420,7 @@ class MultiDecoder(nn.Module):
|
||||
cnn_sigmoid,
|
||||
image_dist,
|
||||
vector_dist,
|
||||
outscale,
|
||||
):
|
||||
super(MultiDecoder, self).__init__()
|
||||
excluded = ("is_first", "is_last", "is_terminal")
|
||||
@@ -444,6 +447,7 @@ class MultiDecoder(nn.Module):
|
||||
norm,
|
||||
kernel_size,
|
||||
minres,
|
||||
outscale=outscale,
|
||||
cnn_sigmoid=cnn_sigmoid,
|
||||
)
|
||||
if self.mlp_shapes:
|
||||
@@ -455,6 +459,8 @@ class MultiDecoder(nn.Module):
|
||||
act,
|
||||
norm,
|
||||
vector_dist,
|
||||
outscale=outscale,
|
||||
name="Decoder",
|
||||
)
|
||||
self._image_dist = image_dist
|
||||
|
||||
@@ -491,21 +497,18 @@ class ConvEncoder(nn.Module):
|
||||
input_shape,
|
||||
depth=32,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
kernel_size=4,
|
||||
minres=4,
|
||||
):
|
||||
super(ConvEncoder, self).__init__()
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
h, w, input_ch = input_shape
|
||||
stages = int(np.log2(h) - np.log2(minres))
|
||||
in_dim = input_ch
|
||||
out_dim = depth
|
||||
layers = []
|
||||
for i in range(int(np.log2(h) - np.log2(minres))):
|
||||
if i == 0:
|
||||
in_dim = input_ch
|
||||
else:
|
||||
in_dim = 2 ** (i - 1) * depth
|
||||
out_dim = 2**i * depth
|
||||
for i in range(stages):
|
||||
layers.append(
|
||||
Conv2dSame(
|
||||
in_channels=in_dim,
|
||||
@@ -515,15 +518,19 @@ class ConvEncoder(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
out_dim *= 2
|
||||
h, w = h // 2, w // 2
|
||||
|
||||
self.outdim = out_dim * h * w
|
||||
self.outdim = out_dim // 2 * h * w
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
def forward(self, obs):
|
||||
obs -= 0.5
|
||||
# (batch, time, h, w, ch) -> (batch * time, h, w, ch)
|
||||
x = obs.reshape((-1,) + tuple(obs.shape[-3:]))
|
||||
# (batch * time, h, w, ch) -> (batch * time, ch, h, w)
|
||||
@@ -542,7 +549,7 @@ class ConvDecoder(nn.Module):
|
||||
shape=(3, 64, 64),
|
||||
depth=32,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
norm=True,
|
||||
kernel_size=4,
|
||||
minres=4,
|
||||
outscale=1.0,
|
||||
@@ -550,29 +557,27 @@ class ConvDecoder(nn.Module):
|
||||
):
|
||||
super(ConvDecoder, self).__init__()
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._shape = shape
|
||||
self._cnn_sigmoid = cnn_sigmoid
|
||||
layer_num = int(np.log2(shape[1]) - np.log2(minres))
|
||||
self._minres = minres
|
||||
self._embed_size = minres**2 * depth * 2 ** (layer_num - 1)
|
||||
out_ch = minres**2 * depth * 2 ** (layer_num - 1)
|
||||
self._embed_size = out_ch
|
||||
|
||||
self._linear_layer = nn.Linear(feat_size, self._embed_size)
|
||||
self._linear_layer.apply(tools.weight_init)
|
||||
in_dim = self._embed_size // (minres**2)
|
||||
self._linear_layer = nn.Linear(feat_size, out_ch)
|
||||
self._linear_layer.apply(tools.uniform_weight_init(outscale))
|
||||
in_dim = out_ch // (minres**2)
|
||||
out_dim = in_dim // 2
|
||||
|
||||
layers = []
|
||||
h, w = minres, minres
|
||||
for i in range(layer_num):
|
||||
out_dim = self._embed_size // (minres**2) // (2 ** (i + 1))
|
||||
bias = False
|
||||
initializer = tools.weight_init
|
||||
if i == layer_num - 1:
|
||||
out_dim = self._shape[0]
|
||||
act = False
|
||||
bias = True
|
||||
norm = False
|
||||
initializer = tools.uniform_weight_init(outscale)
|
||||
|
||||
if i != 0:
|
||||
in_dim = 2 ** (layer_num - (i - 1) - 2) * depth
|
||||
@@ -593,9 +598,11 @@ class ConvDecoder(nn.Module):
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
if act:
|
||||
layers.append(act())
|
||||
[m.apply(initializer) for m in layers[-3:]]
|
||||
in_dim = out_dim
|
||||
out_dim //= 2
|
||||
h, w = h * 2, w * 2
|
||||
|
||||
[m.apply(tools.weight_init) for m in layers[:-1]]
|
||||
layers[-1].apply(tools.uniform_weight_init(outscale))
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def calc_same_pad(self, k, s, d):
|
||||
@@ -613,12 +620,14 @@ class ConvDecoder(nn.Module):
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.layers(x)
|
||||
# (batch, time, -1) -> (batch * time, ch, h, w) necessary???
|
||||
# (batch, time, -1) -> (batch, time, ch, h, w)
|
||||
mean = x.reshape(features.shape[:-1] + self._shape)
|
||||
# (batch * time, ch, h, w) -> (batch * time, h, w, ch)
|
||||
# (batch, time, ch, h, w) -> (batch, time, h, w, ch)
|
||||
mean = mean.permute(0, 1, 3, 4, 2)
|
||||
if self._cnn_sigmoid:
|
||||
mean = F.sigmoid(mean) - 0.5
|
||||
mean = F.sigmoid(mean)
|
||||
else:
|
||||
mean += 0.5
|
||||
return mean
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user