modified based on author's implementation
This commit is contained in:
146
networks.py
146
networks.py
@@ -59,29 +59,33 @@ class RSSM(nn.Module):
|
||||
if self._shared:
|
||||
inp_dim += self._embed
|
||||
for i in range(self._layers_input):
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden))
|
||||
inp_layers.append(self._norm(self._hidden))
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
inp_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(self._act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._inp_layers = nn.Sequential(*inp_layers)
|
||||
self._inp_layers.apply(tools.weight_init)
|
||||
|
||||
if cell == "gru":
|
||||
self._cell = GRUCell(self._hidden, self._deter)
|
||||
self._cell.apply(tools.weight_init)
|
||||
elif cell == "gru_layer_norm":
|
||||
self._cell = GRUCell(self._hidden, self._deter, norm=True)
|
||||
self._cell.apply(tools.weight_init)
|
||||
else:
|
||||
raise NotImplementedError(cell)
|
||||
|
||||
img_out_layers = []
|
||||
inp_dim = self._deter
|
||||
for i in range(self._layers_output):
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden))
|
||||
img_out_layers.append(self._norm(self._hidden))
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
img_out_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(self._act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._img_out_layers = nn.Sequential(*img_out_layers)
|
||||
self._img_out_layers.apply(tools.weight_init)
|
||||
|
||||
obs_out_layers = []
|
||||
if self._temp_post:
|
||||
@@ -89,19 +93,24 @@ class RSSM(nn.Module):
|
||||
else:
|
||||
inp_dim = self._embed
|
||||
for i in range(self._layers_output):
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden))
|
||||
obs_out_layers.append(self._norm(self._hidden))
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(self._act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
||||
self._obs_out_layers.apply(tools.weight_init)
|
||||
|
||||
if self._discrete:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
else:
|
||||
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._ims_stat_layer.apply(tools.weight_init)
|
||||
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
|
||||
self._obs_stat_layer.apply(tools.weight_init)
|
||||
|
||||
def initial(self, batch_size):
|
||||
deter = torch.zeros(batch_size, self._deter).to(self._device)
|
||||
@@ -178,6 +187,7 @@ class RSSM(nn.Module):
|
||||
def obs_step(self, prev_state, prev_action, embed, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
@@ -200,6 +210,7 @@ class RSSM(nn.Module):
|
||||
# this is used for making future image
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
# (batch, stoch, discrete_num)
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
prev_stoch = prev_state["stoch"]
|
||||
if self._discrete:
|
||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||
@@ -317,12 +328,15 @@ class ConvEncoder(nn.Module):
|
||||
out_channels=depth,
|
||||
kernel_size=(kernel, kernel),
|
||||
stride=(2, 2),
|
||||
bias=False,
|
||||
)
|
||||
)
|
||||
h, w = h // 2, w // 2
|
||||
# layers.append(norm([depth, h, w]))
|
||||
layers.append(ChLayerNorm(depth))
|
||||
layers.append(act())
|
||||
h, w = h // 2, w // 2
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
def __call__(self, obs):
|
||||
x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
|
||||
@@ -343,6 +357,7 @@ class ConvDecoder(nn.Module):
|
||||
norm=nn.LayerNorm,
|
||||
shape=(3, 64, 64),
|
||||
kernels=(3, 3, 3, 3),
|
||||
outscale=1.0,
|
||||
):
|
||||
super(ConvDecoder, self).__init__()
|
||||
self._inp_depth = inp_depth
|
||||
@@ -358,19 +373,25 @@ class ConvDecoder(nn.Module):
|
||||
self._linear_layer = nn.Linear(inp_depth, self._embed_size)
|
||||
inp_dim = self._embed_size // 16
|
||||
|
||||
cnnt_layers = []
|
||||
layers = []
|
||||
h, w = 4, 4
|
||||
for i, kernel in enumerate(self._kernels):
|
||||
depth = self._embed_size // 16 // (2 ** (i + 1))
|
||||
act = self._act
|
||||
bias = False
|
||||
initializer = tools.weight_init
|
||||
if i == len(self._kernels) - 1:
|
||||
depth = self._shape[0]
|
||||
act = None
|
||||
act = False
|
||||
bias = True
|
||||
norm = False
|
||||
initializer = tools.uniform_weight_init(outscale)
|
||||
|
||||
if i != 0:
|
||||
inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
|
||||
pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1)
|
||||
pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1)
|
||||
cnnt_layers.append(
|
||||
pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
|
||||
pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
|
||||
layers.append(
|
||||
nn.ConvTranspose2d(
|
||||
inp_dim,
|
||||
depth,
|
||||
@@ -378,26 +399,32 @@ class ConvDecoder(nn.Module):
|
||||
2,
|
||||
padding=(pad_h, pad_w),
|
||||
output_padding=(outpad_h, outpad_w),
|
||||
bias=bias,
|
||||
)
|
||||
)
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(depth))
|
||||
if act:
|
||||
layers.append(act())
|
||||
[m.apply(initializer) for m in layers[-3:]]
|
||||
h, w = h * 2, w * 2
|
||||
# cnnt_layers.append(norm([depth, h, w]))
|
||||
if act is not None:
|
||||
cnnt_layers.append(act())
|
||||
self._cnnt_layers = nn.Sequential(*cnnt_layers)
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def calc_same_pad(self, k, s, d):
|
||||
val = d * (k - 1) - s + 1
|
||||
pad = math.ceil(val / 2)
|
||||
outpad = pad * 2 - val
|
||||
return pad, outpad
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
x = self._linear_layer(features)
|
||||
x = x.reshape([-1, 4, 4, self._embed_size // 16])
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self._cnnt_layers(x)
|
||||
x = self.layers(x)
|
||||
mean = x.reshape(features.shape[:-1] + self._shape)
|
||||
mean = mean.permute(0, 1, 3, 4, 2)
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
torchd.normal.Normal(mean, 1), len(self._shape)
|
||||
)
|
||||
)
|
||||
return tools.SymlogDist(mean)
|
||||
|
||||
|
||||
class DenseHead(nn.Module):
|
||||
@@ -411,7 +438,7 @@ class DenseHead(nn.Module):
|
||||
norm=nn.LayerNorm,
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
unimix_ratio=0.0,
|
||||
outscale=1.0,
|
||||
):
|
||||
super(DenseHead, self).__init__()
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
@@ -423,27 +450,30 @@ class DenseHead(nn.Module):
|
||||
self._norm = norm
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._unimix_ratio = unimix_ratio
|
||||
|
||||
mean_layers = []
|
||||
layers = []
|
||||
for index in range(self._layers):
|
||||
mean_layers.append(nn.Linear(inp_dim, self._units))
|
||||
mean_layers.append(norm(self._units))
|
||||
mean_layers.append(act())
|
||||
layers.append(nn.Linear(inp_dim, self._units, bias=False))
|
||||
layers.append(norm(self._units, eps=1e-03))
|
||||
layers.append(act())
|
||||
if index == 0:
|
||||
inp_dim = self._units
|
||||
mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape)))
|
||||
self._mean_layers = nn.Sequential(*mean_layers)
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
if self._std == "learned":
|
||||
self._std_layer = nn.Linear(self._units, np.prod(self._shape))
|
||||
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
|
||||
self.std_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
x = features
|
||||
mean = self._mean_layers(x)
|
||||
out = self.layers(x)
|
||||
mean = self.mean_layer(out)
|
||||
if self._std == "learned":
|
||||
std = self._std_layer(x)
|
||||
std = torch.softplus(std) + 0.01
|
||||
std = self.std_layer(out)
|
||||
else:
|
||||
std = self._std
|
||||
if self._dist == "normal":
|
||||
@@ -464,8 +494,8 @@ class DenseHead(nn.Module):
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
|
||||
)
|
||||
)
|
||||
if self._dist == "twohot":
|
||||
return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio)
|
||||
if self._dist == "twohot_symlog":
|
||||
return tools.TwoHotDistSymlog(logits=mean)
|
||||
raise NotImplementedError(self._dist)
|
||||
|
||||
|
||||
@@ -481,9 +511,9 @@ class ActionHead(nn.Module):
|
||||
dist="trunc_normal",
|
||||
init_std=0.0,
|
||||
min_std=0.1,
|
||||
action_disc=5,
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
outscale=0,
|
||||
outscale=1.0,
|
||||
):
|
||||
super(ActionHead, self).__init__()
|
||||
self._size = size
|
||||
@@ -493,24 +523,27 @@ class ActionHead(nn.Module):
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
self._action_disc = action_disc
|
||||
self._temp = temp() if callable(temp) else temp
|
||||
self._outscale = outscale
|
||||
|
||||
pre_layers = []
|
||||
for index in range(self._layers):
|
||||
pre_layers.append(nn.Linear(inp_dim, self._units))
|
||||
pre_layers.append(norm(self._units))
|
||||
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
|
||||
pre_layers.append(norm(self._units, eps=1e-03))
|
||||
pre_layers.append(act())
|
||||
if index == 0:
|
||||
inp_dim = self._units
|
||||
self._pre_layers = nn.Sequential(*pre_layers)
|
||||
self._pre_layers.apply(tools.weight_init)
|
||||
|
||||
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
|
||||
self._dist_layer = nn.Linear(self._units, 2 * self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
|
||||
self._dist_layer = nn.Linear(self._units, self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def __call__(self, features, dtype=None):
|
||||
x = features
|
||||
@@ -539,9 +572,11 @@ class ActionHead(nn.Module):
|
||||
dist = tools.SampleDist(dist)
|
||||
elif self._dist == "normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, 2, -1)
|
||||
std = F.softplus(std + self._init_std) + self._min_std
|
||||
dist = torchd.normal.Normal(mean, std)
|
||||
mean, std = torch.split(x, [self._size] * 2, -1)
|
||||
std = (self._max_std - self._min_std) * torch.sigmoid(
|
||||
std + 2.0
|
||||
) + self._min_std
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "normal_1":
|
||||
x = self._dist_layer(x)
|
||||
@@ -574,9 +609,9 @@ class GRUCell(nn.Module):
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
self._update_bias = update_bias
|
||||
self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None)
|
||||
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||
if norm:
|
||||
self._norm = nn.LayerNorm(3 * size)
|
||||
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@@ -625,8 +660,13 @@ class Conv2dSame(torch.nn.Conv2d):
|
||||
return ret
|
||||
|
||||
|
||||
def calc_same_pad(k, s, d):
|
||||
val = d * (k - 1) - s + 1
|
||||
pad = math.ceil(val / 2)
|
||||
outpad = pad * 2 - val
|
||||
return pad, outpad
|
||||
class ChLayerNorm(nn.Module):
|
||||
def __init__(self, ch, eps=1e-03):
|
||||
super(ChLayerNorm, self).__init__()
|
||||
self.norm = torch.nn.LayerNorm(ch, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user