merged action head into MLP and modified configs
This commit is contained in:
150
networks.py
150
networks.py
@@ -632,9 +632,14 @@ class MLP(nn.Module):
|
||||
norm="LayerNorm",
|
||||
dist="normal",
|
||||
std=1.0,
|
||||
min_std=0.1,
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
unimix_ratio=0.01,
|
||||
outscale=1.0,
|
||||
symlog_inputs=False,
|
||||
device="cuda",
|
||||
name="NoName",
|
||||
):
|
||||
super(MLP, self).__init__()
|
||||
self._shape = (shape,) if isinstance(shape, int) else shape
|
||||
@@ -647,15 +652,20 @@ class MLP(nn.Module):
|
||||
self._std = std
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._temp = temp
|
||||
self._unimix_ratio = unimix_ratio
|
||||
|
||||
layers = []
|
||||
self.layers = nn.Sequential()
|
||||
for index in range(self._layers):
|
||||
layers.append(nn.Linear(inp_dim, units, bias=False))
|
||||
layers.append(norm(units, eps=1e-03))
|
||||
layers.append(act())
|
||||
self.layers.add_module(
|
||||
f"{name}_linear{index}", nn.Linear(inp_dim, units, bias=False)
|
||||
)
|
||||
self.layers.add_module(f"{name}_norm{index}", norm(units, eps=1e-03))
|
||||
self.layers.add_module(f"{name}_act{index}", act())
|
||||
if index == 0:
|
||||
inp_dim = units
|
||||
self.layers = nn.Sequential(*layers)
|
||||
self.layers.apply(tools.weight_init)
|
||||
|
||||
if isinstance(self._shape, dict):
|
||||
@@ -664,6 +674,7 @@ class MLP(nn.Module):
|
||||
self.mean_layer[name] = nn.Linear(inp_dim, np.prod(shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
if self._std == "learned":
|
||||
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
|
||||
self.std_layer = nn.ModuleDict()
|
||||
for name, shape in self._shape.items():
|
||||
self.std_layer[name] = nn.Linear(inp_dim, np.prod(shape))
|
||||
@@ -672,6 +683,7 @@ class MLP(nn.Module):
|
||||
self.mean_layer = nn.Linear(inp_dim, np.prod(self._shape))
|
||||
self.mean_layer.apply(tools.uniform_weight_init(outscale))
|
||||
if self._std == "learned":
|
||||
assert dist in ("tanh_normal", "normal", "trunc_normal", "huber"), dist
|
||||
self.std_layer = nn.Linear(units, np.prod(self._shape))
|
||||
self.std_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
@@ -680,6 +692,7 @@ class MLP(nn.Module):
|
||||
if self._symlog_inputs:
|
||||
x = tools.symlog(x)
|
||||
out = self.layers(x)
|
||||
# Used for encoder output
|
||||
if self._shape is None:
|
||||
return out
|
||||
if isinstance(self._shape, dict):
|
||||
@@ -701,98 +714,9 @@ class MLP(nn.Module):
|
||||
return self.dist(self._dist, mean, std, self._shape)
|
||||
|
||||
def dist(self, dist, mean, std, shape):
|
||||
if dist == "normal":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
torchd.normal.Normal(mean, std), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "huber":
|
||||
return tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "binary":
|
||||
return tools.Bernoulli(
|
||||
torchd.independent.Independent(
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
|
||||
)
|
||||
)
|
||||
if dist == "symlog_disc":
|
||||
return tools.DiscDist(logits=mean, device=self._device)
|
||||
if dist == "symlog_mse":
|
||||
return tools.SymlogDist(mean)
|
||||
raise NotImplementedError(dist)
|
||||
|
||||
|
||||
class ActionHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inp_dim,
|
||||
size,
|
||||
layers,
|
||||
units,
|
||||
act=nn.ELU,
|
||||
norm=nn.LayerNorm,
|
||||
dist="trunc_normal",
|
||||
init_std=0.0,
|
||||
min_std=0.1,
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
outscale=1.0,
|
||||
unimix_ratio=0.01,
|
||||
):
|
||||
super(ActionHead, self).__init__()
|
||||
self._size = size
|
||||
self._layers = layers
|
||||
self._units = units
|
||||
self._dist = dist
|
||||
act = getattr(torch.nn, act)
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._temp = temp() if callable(temp) else temp
|
||||
|
||||
pre_layers = []
|
||||
for index in range(self._layers):
|
||||
pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
|
||||
pre_layers.append(norm(self._units, eps=1e-03))
|
||||
pre_layers.append(act())
|
||||
if index == 0:
|
||||
inp_dim = self._units
|
||||
self._pre_layers = nn.Sequential(*pre_layers)
|
||||
self._pre_layers.apply(tools.weight_init)
|
||||
|
||||
if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
|
||||
self._dist_layer = nn.Linear(self._units, 2 * self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
|
||||
self._dist_layer = nn.Linear(self._units, self._size)
|
||||
self._dist_layer.apply(tools.uniform_weight_init(outscale))
|
||||
|
||||
def forward(self, features, dtype=None):
|
||||
x = features
|
||||
x = self._pre_layers(x)
|
||||
if self._dist == "tanh_normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, 2, -1)
|
||||
mean = torch.tanh(mean)
|
||||
std = F.softplus(std + self._init_std) + self._min_std
|
||||
dist = torchd.normal.Normal(mean, std)
|
||||
dist = torchd.transformed_distribution.TransformedDistribution(
|
||||
dist, tools.TanhBijector()
|
||||
)
|
||||
dist = torchd.independent.Independent(dist, 1)
|
||||
dist = tools.SampleDist(dist)
|
||||
elif self._dist == "tanh_normal_5":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, 2, -1)
|
||||
mean = 5 * torch.tanh(mean / 5)
|
||||
std = F.softplus(std + 5) + 5
|
||||
std = F.softplus(std) + self._min_std
|
||||
dist = torchd.normal.Normal(mean, std)
|
||||
dist = torchd.transformed_distribution.TransformedDistribution(
|
||||
dist, tools.TanhBijector()
|
||||
@@ -800,33 +724,41 @@ class ActionHead(nn.Module):
|
||||
dist = torchd.independent.Independent(dist, 1)
|
||||
dist = tools.SampleDist(dist)
|
||||
elif self._dist == "normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, [self._size] * 2, -1)
|
||||
std = (self._max_std - self._min_std) * torch.sigmoid(
|
||||
std + 2.0
|
||||
) + self._min_std
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "normal_1":
|
||||
mean = self._dist_layer(x)
|
||||
dist = torchd.normal.Normal(mean, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
|
||||
elif self._dist == "normal_std_fixed":
|
||||
dist = torchd.normal.Normal(mean, self._std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "trunc_normal":
|
||||
x = self._dist_layer(x)
|
||||
mean, std = torch.split(x, [self._size] * 2, -1)
|
||||
mean = torch.tanh(mean)
|
||||
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
||||
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "onehot":
|
||||
x = self._dist_layer(x)
|
||||
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
|
||||
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
x = self._dist_layer(x)
|
||||
temp = self._temp
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
|
||||
elif dist == "huber":
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
)
|
||||
)
|
||||
elif dist == "binary":
|
||||
dist = tools.Bernoulli(
|
||||
torchd.independent.Independent(
|
||||
torchd.bernoulli.Bernoulli(logits=mean), len(shape)
|
||||
)
|
||||
)
|
||||
elif dist == "symlog_disc":
|
||||
dist = tools.DiscDist(logits=mean, device=self._device)
|
||||
elif dist == "symlog_mse":
|
||||
dist = tools.SymlogDist(mean)
|
||||
else:
|
||||
raise NotImplementedError(self._dist)
|
||||
raise NotImplementedError(dist)
|
||||
return dist
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user