merged action head into MLP and modified configs

This commit is contained in:
NM512
2024-01-05 10:26:48 +09:00
parent e0f2017e28
commit e0487f8206
5 changed files with 133 additions and 231 deletions

View File

@@ -632,9 +632,14 @@ class MLP(nn.Module):
norm="LayerNorm",
dist="normal",
std=1.0,
min_std=0.1,
max_std=1.0,
temp=0.1,
unimix_ratio=0.01,
outscale=1.0,
symlog_inputs=False,
device="cuda",
name="NoName",
):
super(MLP, self).__init__()
self._shape = (shape,) if isinstance(shape, int) else shape
@@ -647,15 +652,20 @@ class MLP(nn.Module):
self._std = std
self._symlog_inputs = symlog_inputs
self._device = device
self._min_std = min_std
self._max_std = max_std
self._temp = temp
self._unimix_ratio = unimix_ratio
layers = []
self.layers = nn.Sequential()
for index in range(self._layers):
layers.append(nn.Linear(inp_dim, units, bias=False))
layers.append(norm(units, eps=1e-03))
layers.append(act())
self.layers.add_module(
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
)
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
self.layers.add_module(f"{name}_act{index}", act())
if index == 0:
inp_dim = units
self.layers = nn.Sequential(*layers)
self.layers.apply(tools.weight_init)
if isinstance(self._shape, dict):
@@ -664,6 +674,7 @@ class MLP(nn.Module):
self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
self.std_layer = nn.ModuleDict()
for name, shape in self._shape.items():
self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
@@ -672,6 +683,7 @@ class MLP(nn.Module):
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
self.mean_layer.apply(tools.uniform_weight_init(outscale))
if self._std == "learned":
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
self.std_layer = nn.Linear(units, np.prod(self._shape))
self.std_layer.apply(tools.uniform_weight_init(outscale))
@@ -680,6 +692,7 @@ class MLP(nn.Module):
if self._symlog_inputs:
x = tools.symlog(x)
out = self.layers(x)
# Used for encoder output
if self._shape is None:
return out
if isinstance(self._shape, dict):
@@ -701,98 +714,9 @@ class MLP(nn.Module):
return self.dist(self._dist, mean, std, self._shape)
def dist(self, dist, mean, std, shape):
if dist == "normal":
return tools.ContDist(
torchd.independent.Independent(
torchd.normal.Normal(mean, std), len(shape)
)
)
if dist == "huber":
return tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
)
)
if dist == "binary":
return tools.Bernoulli(
torchd.independent.Independent(
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
)
)
if dist == "symlog_disc":
return tools.DiscDist(logits=mean, device=self._device)
if dist == "symlog_mse":
return tools.SymlogDist(mean)
raise NotImplementedError(dist)
class ActionHead(nn.Module):
def __init__(
self,
inp_dim,
size,
layers,
units,
act=nn.ELU,
norm=nn.LayerNorm,
dist="trunc_normal",
init_std=0.0,
min_std=0.1,
max_std=1.0,
temp=0.1,
outscale=1.0,
unimix_ratio=0.01,
):
super(ActionHead, self).__init__()
self._size = size
self._layers = layers
self._units = units
self._dist = dist
act = getattr(torch.nn, act)
norm = getattr(torch.nn, norm)
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._unimix_ratio = unimix_ratio
self._temp = temp() if callable(temp) else temp
pre_layers = []
for index in range(self._layers):
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
pre_layers.append(norm(self._units, eps=1e-03))
pre_layers.append(act())
if index == 0:
inp_dim = self._units
self._pre_layers = nn.Sequential(*pre_layers)
self._pre_layers.apply(tools.weight_init)
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
self._dist_layer = nn.Linear(self._units, 2 * self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
self._dist_layer = nn.Linear(self._units, self._size)
self._dist_layer.apply(tools.uniform_weight_init(outscale))
def forward(self, features, dtype=None):
x = features
x = self._pre_layers(x)
if self._dist == "tanh_normal":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
mean = torch.tanh(mean)
std = F.softplus(std + self._init_std) + self._min_std
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(
dist, tools.TanhBijector()
)
dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == "tanh_normal_5":
x = self._dist_layer(x)
mean, std = torch.split(x, 2, -1)
mean = 5 * torch.tanh(mean / 5)
std = F.softplus(std + 5) + 5
std = F.softplus(std) + self._min_std
dist = torchd.normal.Normal(mean, std)
dist = torchd.transformed_distribution.TransformedDistribution(
dist, tools.TanhBijector()
@@ -800,33 +724,41 @@ class ActionHead(nn.Module):
dist = torchd.independent.Independent(dist, 1)
dist = tools.SampleDist(dist)
elif self._dist == "normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
std = (self._max_std - self._min_std) * torch.sigmoid(
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "normal_1":
mean = self._dist_layer(x)
dist = torchd.normal.Normal(mean, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
elif self._dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "trunc_normal":
x = self._dist_layer(x)
mean, std = torch.split(x, [self._size] * 2, -1)
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "onehot":
x = self._dist_layer(x)
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
x = self._dist_layer(x)
temp = self._temp
dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
elif dist == "huber":
dist = tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
)
)
elif dist == "binary":
dist = tools.Bernoulli(
torchd.independent.Independent(
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
)
)
elif dist == "symlog_disc":
dist = tools.DiscDist(logits=mean, device=self._device)
elif dist == "symlog_mse":
dist = tools.SymlogDist(mean)
else:
raise NotImplementedError(self._dist)
raise NotImplementedError(dist)
return dist