modified based on author's implementation

This commit is contained in:
NM512
2023-03-18 08:38:23 +09:00
parent a678a509b9
commit 6273444394
6 changed files with 371 additions and 229 deletions

View File

@@ -59,29 +59,33 @@ class RSSM(nn.Module):
if self._shared:
inp_dim += self._embed
for i in range(self._layers_input):
inp_layers.append(nn.Linear(inp_dim, self._hidden))
inp_layers.append(self._norm(self._hidden))
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
inp_layers.append(self._norm(self._hidden, eps=1e-03))
inp_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._inp_layers = nn.Sequential(*inp_layers)
self._inp_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
elif cell == "gru_layer_norm":
self._cell = GRUCell(self._hidden, self._deter, norm=True)
self._cell.apply(tools.weight_init)
else:
raise NotImplementedError(cell)
img_out_layers = []
inp_dim = self._deter
for i in range(self._layers_output):
img_out_layers.append(nn.Linear(inp_dim, self._hidden))
img_out_layers.append(self._norm(self._hidden))
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
img_out_layers.append(self._norm(self._hidden, eps=1e-03))
img_out_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._img_out_layers = nn.Sequential(*img_out_layers)
self._img_out_layers.apply(tools.weight_init)
obs_out_layers = []
if self._temp_post:
@@ -89,19 +93,24 @@ class RSSM(nn.Module):
else:
inp_dim = self._embed
for i in range(self._layers_output):
obs_out_layers.append(nn.Linear(inp_dim, self._hidden))
obs_out_layers.append(self._norm(self._hidden))
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
obs_out_layers.append(self._act())
if i == 0:
inp_dim = self._hidden
self._obs_out_layers = nn.Sequential(*obs_out_layers)
self._obs_out_layers.apply(tools.weight_init)
if self._discrete:
self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._ims_stat_layer.apply(tools.weight_init)
self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
self._obs_stat_layer.apply(tools.weight_init)
else:
self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._ims_stat_layer.apply(tools.weight_init)
self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
self._obs_stat_layer.apply(tools.weight_init)
def initial(self, batch_size):
deter = torch.zeros(batch_size, self._deter).to(self._device)
@@ -178,6 +187,7 @@ class RSSM(nn.Module):
def obs_step(self, prev_state, prev_action, embed, sample=True):
# If shared is True, the prior and the posterior both use the same networks (_inp_layers, _img_out_layers, _ims_stat_layer);
# otherwise, the posterior uses a separate network (_obs_out_layers) with prior[deter] and embed as inputs.
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
@@ -200,6 +210,7 @@ class RSSM(nn.Module):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -317,12 +328,15 @@ class ConvEncoder(nn.Module):
out_channels=depth,
kernel_size=(kernel, kernel),
stride=(2, 2),
bias=False,
)
)
h, w = h // 2, w // 2
# layers.append(norm([depth, h, w]))
layers.append(ChLayerNorm(depth))
layers.append(act())
h, w = h // 2, w // 2
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
def __call__(self, obs):
x = obs["image"].reshape((-1,) + tuple(obs["image"].shape[-3:]))
@@ -343,6 +357,7 @@ class ConvDecoder(nn.Module):
norm=nn.LayerNorm,
shape=(3, 64, 64),
kernels=(3, 3, 3, 3),
outscale=1.0,
):
super(ConvDecoder, self).__init__()
self._inp_depth = inp_depth
@@ -358,19 +373,25 @@ class ConvDecoder(nn.Module):
self._linear_layer = nn.Linear(inp_depth, self._embed_size)
inp_dim = self._embed_size // 16
cnnt_layers = []
layers = []
h, w = 4, 4
for i, kernel in enumerate(self._kernels):
depth = self._embed_size // 16 // (2 ** (i + 1))
act = self._act
bias = False
initializer = tools.weight_init
if i == len(self._kernels) - 1:
depth = self._shape[0]
act = None
act = False
bias = True
norm = False
initializer = tools.uniform_weight_init(outscale)
if i != 0:
inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
pad_h, outpad_h = calc_same_pad(k=kernel, s=2, d=1)
pad_w, outpad_w = calc_same_pad(k=kernel, s=2, d=1)
cnnt_layers.append(
pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
layers.append(
nn.ConvTranspose2d(
inp_dim,
depth,
@@ -378,26 +399,32 @@ class ConvDecoder(nn.Module):
2,
padding=(pad_h, pad_w),
output_padding=(outpad_h, outpad_w),
bias=bias,
)
)
if norm:
layers.append(ChLayerNorm(depth))
if act:
layers.append(act())
[m.apply(initializer) for m in layers[-3:]]
h, w = h * 2, w * 2
# cnnt_layers.append(norm([depth, h, w]))
if act is not None:
cnnt_layers.append(act())
self._cnnt_layers = nn.Sequential(*cnnt_layers)
self.layers = nn.Sequential(*layers)
def calc_same_pad(self, k, s, d):
    """Compute the ``(padding, output_padding)`` pair for a transposed conv.

    The result is fed to ``nn.ConvTranspose2d`` as its ``padding`` and
    ``output_padding`` arguments by the layer-building loop of this class.

    Args:
        k: Kernel size along one spatial axis.
        s: Stride along that axis (callers here pass 2).
        d: Presumably the dilation along that axis (callers here pass 1).

    Returns:
        Tuple ``(pad, outpad)`` of integers.
    """
    # Amount by which the (dilated) kernel extent exceeds the stride.
    excess = d * (k - 1) - s + 1
    # Round the half-excess up; the rounding error is returned as outpad.
    half_up = math.ceil(excess / 2)
    return half_up, 2 * half_up - excess
def __call__(self, features, dtype=None):
x = self._linear_layer(features)
x = x.reshape([-1, 4, 4, self._embed_size // 16])
x = x.permute(0, 3, 1, 2)
x = self._cnnt_layers(x)
x = self.layers(x)
mean = x.reshape(features.shape[:-1] + self._shape)
mean = mean.permute(0, 1, 3, 4, 2)
return tools.ContDist(
torchd.independent.Independent(
torchd.normal.Normal(mean, 1), len(self._shape)
)
)
return tools.SymlogDist(mean)
class DenseHead(nn.Module):
@@ -411,7 +438,7 @@ class DenseHead(nn.Module):
norm=nn.LayerNorm,
dist="normal",
std=1.0,
unimix_ratio=0.0,
outscale=1.0,
):
super(DenseHead, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@@ -423,27 +450,30 @@ class DenseHead(nn.Module):
self._norm = norm
self._dist = dist
self._std = std
self._unimix_ratio = unimix_ratio
mean_layers = []
layers = []
for index in range(self._layers):
mean_layers.append(nn.Linear(inp_dim, self._units))
mean_layers.append(norm(self._units))
mean_layers.append(act())
layers.append(nn.Linear(inp_dim, self._units, bias=False))
layers.append(norm(self._units, eps=1e-03))
layers.append(act())
if index == 0:
inp_dim = self._units
mean_layers.append(nn.Linear(inp_dim, np.prod(self._shape)))
self._mean_layers = nn.Sequential(*mean_layers)
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
self._std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer = nn.Linear(self._units, np.prod(self._shape))
self.std_layer.apply(tools.uniform_weight_init(outscale))
def __call__(self, features, dtype=None):
x = features
mean = self._mean_layers(x)
out = self.layers(x)
mean = self.mean_layer(out)
if self._std == "learned":
std = self._std_layer(x)
std = torch.softplus(std) + 0.01
std = self.std_layer(out)
else:
std = self._std
if self._dist == "normal":
@@ -464,8 +494,8 @@ class DenseHead(nn.Module):
torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)
)
)
if self._dist == "twohot":
return tools.TwoHotDist(logits=mean, unimix_ratio=self._unimix_ratio)
if self._dist == "twohot_symlog":
return tools.TwoHotDistSymlog(logits=mean)
raise NotImplementedError(self._dist)
@@ -481,9 +511,9 @@ class ActionHead(nn.Module):
dist="trunc_normal",
init_std=0.0,
min_std=0.1,
action_disc=5,
max_std=1.0,
temp=0.1,
outscale=0,
outscale=1.0,
):
super(ActionHead, self).__init__()
self._size = size
@@ -493,24 +523,27 @@ class ActionHead(nn.Module):
self._act = act
self._norm = norm
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._action_disc = action_disc
self._temp = temp() if callable(temp) else temp
self._outscale = outscale
pre_layers = []
for index in range(self._layers):
pre_layers.append(nn.Linear(inp_dim, self._units))
pre_layers.append(norm(self._units))
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
pre_layers.append(norm(self._units, eps=1e-03))
pre_layers.append(act())
if index == 0:
inp_dim = self._units
self._pre_layers = nn.Sequential(*pre_layers)
self._pre_layers.apply(tools.weight_init)
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
self._dist_layer = nn.Linear(self._units, 2 * self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
self._dist_layer = nn.Linear(self._units, self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
def __call__(self, features, dtype=None):
x = features
@@ -539,9 +572,11 @@ class ActionHead(nn.Module):
dist = tools.SampleDist(dist)
elif self._dist == "normal":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
std = F.softplus(std + self._init_std) + self._min_std
dist = torchd.normal.Normal(mean, std)
mean, std = torch.split(x, [self._size] * 2, -1)
std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
x = self._dist_layer(x)
@@ -574,9 +609,9 @@ class GRUCell(nn.Module):
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = nn.Linear(inp_size + size, 3 * size, bias=norm is not None)
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
if norm:
self._norm = nn.LayerNorm(3 * size)
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
@property
def state_size(self):
@@ -625,8 +660,13 @@ class Conv2dSame(torch.nn.Conv2d):
return ret
def calc_same_pad(k, s, d):
    """Return ``(pad, outpad)`` derived from kernel size, stride and dilation.

    Args:
        k: Kernel size along one spatial axis.
        s: Stride along that axis.
        d: Presumably the dilation along that axis.

    Returns:
        Tuple ``(pad, outpad)`` where ``pad`` is half of
        ``d * (k - 1) - s + 1`` rounded up and ``outpad`` is the amount
        added by that rounding (``2 * pad`` minus the same quantity).
    """
    surplus = d * (k - 1) - s + 1
    padding = math.ceil(surplus / 2)
    output_padding = 2 * padding - surplus
    return padding, output_padding
class ChLayerNorm(nn.Module):
    """Layer normalization applied over dimension 1 of a 4-D tensor.

    The wrapped ``torch.nn.LayerNorm`` normalizes the trailing dimension, so
    the input is permuted to put dimension 1 last, normalized, and permuted
    back to its original layout.

    Args:
        ch: Size of dimension 1 of the input (the normalized shape).
        eps: Epsilon forwarded to ``torch.nn.LayerNorm``.
    """

    def __init__(self, ch, eps=1e-03):
        super().__init__()
        # Attribute name is kept so parameter keys stay "norm.weight"/"norm.bias".
        self.norm = torch.nn.LayerNorm(ch, eps=eps)

    def forward(self, x):
        """Normalize ``x`` across dimension 1 and return it with the input's layout."""
        # (0, 1, 2, 3) -> (0, 2, 3, 1): dimension 1 becomes the trailing axis.
        trailing_ch = x.permute(0, 2, 3, 1)
        # Normalize, then undo the permutation: (0, 2, 3, 1) -> (0, 1, 2, 3).
        return self.norm(trailing_ch).permute(0, 3, 1, 2)