modified weight initialization

This commit is contained in:
NM512
2024-01-05 10:46:54 +09:00
parent 4fe9b29ebe
commit a9e85e8b7c
3 changed files with 61 additions and 40 deletions

View File

@@ -920,7 +920,9 @@ def weight_init(m):
denoms = (in_num + out_num) / 2.0
scale = 1.0 / denoms
std = np.sqrt(scale) / 0.87962566103423978
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
nn.init.trunc_normal_(
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
@@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
space = m.kernel_size[0] * m.kernel_size[1]
in_num = space * m.in_channels
out_num = space * m.out_channels
denoms = (in_num + out_num) / 2.0
scale = given_scale / denoms
limit = np.sqrt(3 * scale)
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
if hasattr(m.bias, "data"):
m.bias.data.fill_(0.0)
elif isinstance(m, nn.LayerNorm):
m.weight.data.fill_(1.0)
if hasattr(m.bias, "data"):