erased unused options

This commit is contained in:
NM512
2024-01-05 23:23:09 +09:00
parent a27711ab96
commit 7f66ed5333
6 changed files with 84 additions and 211 deletions

View File

@@ -16,18 +16,13 @@ class RSSM(nn.Module):
stoch=30,
deter=200,
hidden=200,
layers_input=1,
layers_output=1,
rec_depth=1,
shared=False,
discrete=False,
act="SiLU",
norm="LayerNorm",
norm=True,
mean_act="none",
std_act="softplus",
temp_post=True,
min_std=0.1,
cell="gru",
unimix_ratio=0.01,
initial="learned",
num_actions=None,
@@ -39,16 +34,11 @@ class RSSM(nn.Module):
self._deter = deter
self._hidden = hidden
self._min_std = min_std
self._layers_input = layers_input
self._layers_output = layers_output
self._rec_depth = rec_depth
self._shared = shared
self._discrete = discrete
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._mean_act = mean_act
self._std_act = std_act
self._temp_post = temp_post
self._unimix_ratio = unimix_ratio
self._initial = initial
self._num_actions = num_actions
@@ -60,47 +50,30 @@ class RSSM(nn.Module):
inp_dim = self._stoch * self._discrete + num_actions
else:
inp_dim = self._stoch + num_actions
if self._shared:
inp_dim += self._embed
for i in range(self._layers_input):
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
inp_layers.append(norm(self._hidden, eps=1e-03))
inp_layers.append(act())
if i == 0:
inp_dim = self._hidden
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
inp_layers.append(act())
self._img_in_layers = nn.Sequential(*inp_layers)
self._img_in_layers.apply(tools.weight_init)
if cell == "gru":
self._cell = GRUCell(self._hidden, self._deter)
self._cell.apply(tools.weight_init)
elif cell == "gru_layer_norm":
self._cell = GRUCell(self._hidden, self._deter, norm=True)
self._cell.apply(tools.weight_init)
else:
raise NotImplementedError(cell)
self._cell = GRUCell(self._hidden, self._deter, norm=norm)
self._cell.apply(tools.weight_init)
img_out_layers = []
inp_dim = self._deter
for i in range(self._layers_output):
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
img_out_layers.append(norm(self._hidden, eps=1e-03))
img_out_layers.append(act())
if i == 0:
inp_dim = self._hidden
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
img_out_layers.append(act())
self._img_out_layers = nn.Sequential(*img_out_layers)
self._img_out_layers.apply(tools.weight_init)
obs_out_layers = []
if self._temp_post:
inp_dim = self._deter + self._embed
else:
inp_dim = self._embed
for i in range(self._layers_output):
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
obs_out_layers.append(norm(self._hidden, eps=1e-03))
obs_out_layers.append(act())
if i == 0:
inp_dim = self._hidden
inp_dim = self._deter + self._embed
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
if norm:
obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
obs_out_layers.append(act())
self._obs_out_layers = nn.Sequential(*obs_out_layers)
self._obs_out_layers.apply(tools.weight_init)
@@ -200,9 +173,6 @@ class RSSM(nn.Module):
return dist
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
prev_state = self.initial(len(is_first))
@@ -223,41 +193,28 @@ class RSSM(nn.Module):
val * (1.0 - is_first_r) + init_state[key] * is_first_r
)
prior = self.img_step(prev_state, prev_action, None, sample)
if self._shared:
post = self.img_step(prev_state, prev_action, embed, sample)
prior = self.img_step(prev_state, prev_action)
x = torch.cat([prior["deter"], embed], -1)
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
x = self._obs_out_layers(x)
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
stats = self._suff_stats_layer("obs", x)
if sample:
stoch = self.get_dist(stats).sample()
else:
if self._temp_post:
x = torch.cat([prior["deter"], embed], -1)
else:
x = embed
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
x = self._obs_out_layers(x)
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
stats = self._suff_stats_layer("obs", x)
if sample:
stoch = self.get_dist(stats).sample()
else:
stoch = self.get_dist(stats).mode()
post = {"stoch": stoch, "deter": prior["deter"], **stats}
stoch = self.get_dist(stats).mode()
post = {"stoch": stoch, "deter": prior["deter"], **stats}
return post, prior
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
def img_step(self, prev_state, prev_action, sample=True):
# (batch, stoch, discrete_num)
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
prev_stoch = prev_stoch.reshape(shape)
if self._shared:
if embed is None:
shape = list(prev_action.shape[:-1]) + [self._embed]
embed = torch.zeros(shape)
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
x = torch.cat([prev_stoch, prev_action, embed], -1)
else:
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
x = torch.cat([prev_stoch, prev_action], -1)
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
x = self._img_in_layers(x)
for _ in range(self._rec_depth): # rec depth is not correctly implemented
@@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
layers = []
for i in range(stages):
layers.append(
Conv2dSame(
Conv2dSamePad(
in_channels=in_dim,
out_channels=out_dim,
kernel_size=kernel_size,
@@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
)
)
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(ImgChLayerNorm(out_dim))
layers.append(act())
in_dim = out_dim
out_dim *= 2
@@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
)
)
if norm:
layers.append(ChLayerNorm(out_dim))
layers.append(ImgChLayerNorm(out_dim))
if act:
layers.append(act())
in_dim = out_dim
@@ -637,7 +594,7 @@ class MLP(nn.Module):
layers,
units,
act="SiLU",
norm="LayerNorm",
norm=True,
dist="normal",
std=1.0,
min_std=0.1,
@@ -654,11 +611,9 @@ class MLP(nn.Module):
self._shape = (shape,) if isinstance(shape, int) else shape
if self._shape is not None and len(self._shape) == 0:
self._shape = (1,)
self._layers = layers
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._dist = dist
self._std = std
self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
self._min_std = min_std
self._max_std = max_std
self._absmax = absmax
@@ -668,13 +623,16 @@ class MLP(nn.Module):
self._device = device
self.layers = nn.Sequential()
for index in range(self._layers):
for i in range(layers):
self.layers.add_module(
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
)
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
self.layers.add_module(f"{name}_act{index}", act())
if index == 0:
if norm:
self.layers.add_module(
f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
)
self.layers.add_module(f"{name}_act{i}", act())
if i == 0:
inp_dim = units
self.layers.apply(tools.weight_init)
@@ -783,16 +741,18 @@ class MLP(nn.Module):
class GRUCell(nn.Module):
def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
super(GRUCell, self).__init__()
self._inp_size = inp_size
self._size = size
self._act = act
self._norm = norm
self._update_bias = update_bias
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
self.layers = nn.Sequential()
self.layers.add_module(
"GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
)
if norm:
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
@property
def state_size(self):
@@ -800,9 +760,7 @@ class GRUCell(nn.Module):
def forward(self, inputs, state):
state = state[0] # Keras wraps the state in a list.
parts = self._layer(torch.cat([inputs, state], -1))
if self._norm:
parts = self._norm(parts)
parts = self.layers(torch.cat([inputs, state], -1))
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
reset = torch.sigmoid(reset)
cand = self._act(reset * cand)
@@ -811,7 +769,7 @@ class GRUCell(nn.Module):
return output, [output]
class Conv2dSame(torch.nn.Conv2d):
class Conv2dSamePad(torch.nn.Conv2d):
def calc_same_pad(self, i, k, s, d):
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
@@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d):
return ret
class ChLayerNorm(nn.Module):
class ImgChLayerNorm(nn.Module):
def __init__(self, ch, eps=1e-03):
super(ChLayerNorm, self).__init__()
super(ImgChLayerNorm, self).__init__()
self.norm = torch.nn.LayerNorm(ch, eps=eps)
def forward(self, x):