updated result, requirements and torch version

This commit is contained in:
NM512
2023-03-24 07:51:57 +09:00
parent 2504426164
commit 942eae10a9
6 changed files with 36 additions and 25 deletions

View File

@@ -514,6 +514,7 @@ class ActionHead(nn.Module):
max_std=1.0,
temp=0.1,
outscale=1.0,
unimix_ratio=0.01,
):
super(ActionHead, self).__init__()
self._size = size
@@ -525,6 +526,7 @@ class ActionHead(nn.Module):
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._unimix_ratio = unimix_ratio
self._temp = temp() if callable(temp) else temp
pre_layers = []
@@ -591,7 +593,7 @@ class ActionHead(nn.Module):
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "onehot":
x = self._dist_layer(x)
dist = tools.OneHotDist(x)
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
x = self._dist_layer(x)
temp = self._temp