modified weight initialization
This commit is contained in:
14
tools.py
14
tools.py
@@ -920,7 +920,9 @@ def weight_init(m):
|
||||
denoms = (in_num + out_num) / 2.0
|
||||
scale = 1.0 / denoms
|
||||
std = np.sqrt(scale) / 0.87962566103423978
|
||||
nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
|
||||
nn.init.trunc_normal_(
|
||||
m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std
|
||||
)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
@@ -940,6 +942,16 @@ def uniform_weight_init(given_scale):
|
||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||
space = m.kernel_size[0] * m.kernel_size[1]
|
||||
in_num = space * m.in_channels
|
||||
out_num = space * m.out_channels
|
||||
denoms = (in_num + out_num) / 2.0
|
||||
scale = given_scale / denoms
|
||||
limit = np.sqrt(3 * scale)
|
||||
nn.init.uniform_(m.weight.data, a=-limit, b=limit)
|
||||
if hasattr(m.bias, "data"):
|
||||
m.bias.data.fill_(0.0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
m.weight.data.fill_(1.0)
|
||||
if hasattr(m.bias, "data"):
|
||||
|
||||
Reference in New Issue
Block a user