erased unused options
This commit is contained in:
146
networks.py
146
networks.py
@@ -16,18 +16,13 @@ class RSSM(nn.Module):
|
||||
stoch=30,
|
||||
deter=200,
|
||||
hidden=200,
|
||||
layers_input=1,
|
||||
layers_output=1,
|
||||
rec_depth=1,
|
||||
shared=False,
|
||||
discrete=False,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
mean_act="none",
|
||||
std_act="softplus",
|
||||
temp_post=True,
|
||||
min_std=0.1,
|
||||
cell="gru",
|
||||
unimix_ratio=0.01,
|
||||
initial="learned",
|
||||
num_actions=None,
|
||||
@@ -39,16 +34,11 @@ class RSSM(nn.Module):
|
||||
self._deter = deter
|
||||
self._hidden = hidden
|
||||
self._min_std = min_std
|
||||
self._layers_input = layers_input
|
||||
self._layers_output = layers_output
|
||||
self._rec_depth = rec_depth
|
||||
self._shared = shared
|
||||
self._discrete = discrete
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._mean_act = mean_act
|
||||
self._std_act = std_act
|
||||
self._temp_post = temp_post
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._initial = initial
|
||||
self._num_actions = num_actions
|
||||
@@ -60,47 +50,30 @@ class RSSM(nn.Module):
|
||||
inp_dim = self._stoch * self._discrete + num_actions
|
||||
else:
|
||||
inp_dim = self._stoch + num_actions
|
||||
if self._shared:
|
||||
inp_dim += self._embed
|
||||
for i in range(self._layers_input):
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
inp_layers.append(norm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
inp_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
inp_layers.append(act())
|
||||
self._img_in_layers = nn.Sequential(*inp_layers)
|
||||
self._img_in_layers.apply(tools.weight_init)
|
||||
if cell == "gru":
|
||||
self._cell = GRUCell(self._hidden, self._deter)
|
||||
self._cell.apply(tools.weight_init)
|
||||
elif cell == "gru_layer_norm":
|
||||
self._cell = GRUCell(self._hidden, self._deter, norm=True)
|
||||
self._cell.apply(tools.weight_init)
|
||||
else:
|
||||
raise NotImplementedError(cell)
|
||||
self._cell = GRUCell(self._hidden, self._deter, norm=norm)
|
||||
self._cell.apply(tools.weight_init)
|
||||
|
||||
img_out_layers = []
|
||||
inp_dim = self._deter
|
||||
for i in range(self._layers_output):
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
img_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
img_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
img_out_layers.append(act())
|
||||
self._img_out_layers = nn.Sequential(*img_out_layers)
|
||||
self._img_out_layers.apply(tools.weight_init)
|
||||
|
||||
obs_out_layers = []
|
||||
if self._temp_post:
|
||||
inp_dim = self._deter + self._embed
|
||||
else:
|
||||
inp_dim = self._embed
|
||||
for i in range(self._layers_output):
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
obs_out_layers.append(norm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(act())
|
||||
if i == 0:
|
||||
inp_dim = self._hidden
|
||||
inp_dim = self._deter + self._embed
|
||||
obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
|
||||
if norm:
|
||||
obs_out_layers.append(nn.LayerNorm(self._hidden, eps=1e-03))
|
||||
obs_out_layers.append(act())
|
||||
self._obs_out_layers = nn.Sequential(*obs_out_layers)
|
||||
self._obs_out_layers.apply(tools.weight_init)
|
||||
|
||||
@@ -200,9 +173,6 @@ class RSSM(nn.Module):
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
prev_state = self.initial(len(is_first))
|
||||
@@ -223,41 +193,28 @@ class RSSM(nn.Module):
|
||||
val * (1.0 - is_first_r) + init_state[key] * is_first_r
|
||||
)
|
||||
|
||||
prior = self.img_step(prev_state, prev_action, None, sample)
|
||||
if self._shared:
|
||||
post = self.img_step(prev_state, prev_action, embed, sample)
|
||||
prior = self.img_step(prev_state, prev_action)
|
||||
x = torch.cat([prior["deter"], embed], -1)
|
||||
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
||||
x = self._obs_out_layers(x)
|
||||
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
||||
stats = self._suff_stats_layer("obs", x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
if self._temp_post:
|
||||
x = torch.cat([prior["deter"], embed], -1)
|
||||
else:
|
||||
x = embed
|
||||
# (batch_size, prior_deter + embed) -> (batch_size, hidden)
|
||||
x = self._obs_out_layers(x)
|
||||
# (batch_size, hidden) -> (batch_size, stoch, discrete_num)
|
||||
stats = self._suff_stats_layer("obs", x)
|
||||
if sample:
|
||||
stoch = self.get_dist(stats).sample()
|
||||
else:
|
||||
stoch = self.get_dist(stats).mode()
|
||||
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
||||
stoch = self.get_dist(stats).mode()
|
||||
post = {"stoch": stoch, "deter": prior["deter"], **stats}
|
||||
return post, prior
|
||||
|
||||
# this is used for making future image
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
def img_step(self, prev_state, prev_action, sample=True):
|
||||
# (batch, stoch, discrete_num)
|
||||
prev_stoch = prev_state["stoch"]
|
||||
if self._discrete:
|
||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||
# (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
|
||||
prev_stoch = prev_stoch.reshape(shape)
|
||||
if self._shared:
|
||||
if embed is None:
|
||||
shape = list(prev_action.shape[:-1]) + [self._embed]
|
||||
embed = torch.zeros(shape)
|
||||
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
|
||||
x = torch.cat([prev_stoch, prev_action, embed], -1)
|
||||
else:
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action)
|
||||
x = torch.cat([prev_stoch, prev_action], -1)
|
||||
# (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
|
||||
x = self._img_in_layers(x)
|
||||
for _ in range(self._rec_depth): # rec depth is not correctly implemented
|
||||
@@ -508,7 +465,7 @@ class ConvEncoder(nn.Module):
|
||||
layers = []
|
||||
for i in range(stages):
|
||||
layers.append(
|
||||
Conv2dSame(
|
||||
Conv2dSamePad(
|
||||
in_channels=in_dim,
|
||||
out_channels=out_dim,
|
||||
kernel_size=kernel_size,
|
||||
@@ -517,7 +474,7 @@ class ConvEncoder(nn.Module):
|
||||
)
|
||||
)
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(ImgChLayerNorm(out_dim))
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
out_dim *= 2
|
||||
@@ -593,7 +550,7 @@ class ConvDecoder(nn.Module):
|
||||
)
|
||||
)
|
||||
if norm:
|
||||
layers.append(ChLayerNorm(out_dim))
|
||||
layers.append(ImgChLayerNorm(out_dim))
|
||||
if act:
|
||||
layers.append(act())
|
||||
in_dim = out_dim
|
||||
@@ -637,7 +594,7 @@ class MLP(nn.Module):
|
||||
layers,
|
||||
units,
|
||||
act="SiLU",
|
||||
norm="LayerNorm",
|
||||
norm=True,
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
min_std=0.1,
|
||||
@@ -654,11 +611,9 @@ class MLP(nn.Module):
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
if self._shape is not None and len(self._shape) == 0:
|
||||
self._shape = (1,)
|
||||
self._layers = layers
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._std = std if isinstance(std, str) else torch.tensor((std,), device=device)
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._absmax = absmax
|
||||
@@ -668,13 +623,16 @@ class MLP(nn.Module):
|
||||
self._device = device
|
||||
|
||||
self.layers = nn.Sequential()
|
||||
for index in range(self._layers):
|
||||
for i in range(layers):
|
||||
self.layers.add_module(
|
||||
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
|
||||
f"{name}_linear{i}", nn.Linear(inp_dim, units, bias=False)
|
||||
)
|
||||
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
|
||||
self.layers.add_module(f"{name}_act{index}", act())
|
||||
if index == 0:
|
||||
if norm:
|
||||
self.layers.add_module(
|
||||
f"{name}_norm{i}", nn.LayerNorm(units, eps=1e-03)
|
||||
)
|
||||
self.layers.add_module(f"{name}_act{i}", act())
|
||||
if i == 0:
|
||||
inp_dim = units
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
@@ -783,16 +741,18 @@ class MLP(nn.Module):
|
||||
|
||||
|
||||
class GRUCell(nn.Module):
|
||||
def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
|
||||
def __init__(self, inp_size, size, norm=True, act=torch.tanh, update_bias=-1):
|
||||
super(GRUCell, self).__init__()
|
||||
self._inp_size = inp_size
|
||||
self._size = size
|
||||
self._act = act
|
||||
self._norm = norm
|
||||
self._update_bias = update_bias
|
||||
self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||
self.layers = nn.Sequential()
|
||||
self.layers.add_module(
|
||||
"GRU_linear", nn.Linear(inp_size + size, 3 * size, bias=False)
|
||||
)
|
||||
if norm:
|
||||
self._norm = nn.LayerNorm(3 * size, eps=1e-03)
|
||||
self.layers.add_module("GRU_norm", nn.LayerNorm(3 * size, eps=1e-03))
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@@ -800,9 +760,7 @@ class GRUCell(nn.Module):
|
||||
|
||||
def forward(self, inputs, state):
|
||||
state = state[0] # Keras wraps the state in a list.
|
||||
parts = self._layer(torch.cat([inputs, state], -1))
|
||||
if self._norm:
|
||||
parts = self._norm(parts)
|
||||
parts = self.layers(torch.cat([inputs, state], -1))
|
||||
reset, cand, update = torch.split(parts, [self._size] * 3, -1)
|
||||
reset = torch.sigmoid(reset)
|
||||
cand = self._act(reset * cand)
|
||||
@@ -811,7 +769,7 @@ class GRUCell(nn.Module):
|
||||
return output, [output]
|
||||
|
||||
|
||||
class Conv2dSame(torch.nn.Conv2d):
|
||||
class Conv2dSamePad(torch.nn.Conv2d):
|
||||
def calc_same_pad(self, i, k, s, d):
|
||||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
|
||||
|
||||
@@ -841,9 +799,9 @@ class Conv2dSame(torch.nn.Conv2d):
|
||||
return ret
|
||||
|
||||
|
||||
class ChLayerNorm(nn.Module):
|
||||
class ImgChLayerNorm(nn.Module):
|
||||
def __init__(self, ch, eps=1e-03):
|
||||
super(ChLayerNorm, self).__init__()
|
||||
super(ImgChLayerNorm, self).__init__()
|
||||
self.norm = torch.nn.LayerNorm(ch, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
Reference in New Issue
Block a user