updated result, requirements and torch version
This commit is contained in:
@@ -514,6 +514,7 @@ class ActionHead(nn.Module):
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
outscale=1.0,
|
||||
unimix_ratio=0.01,
|
||||
):
|
||||
super(ActionHead, self).__init__()
|
||||
self._size = size
|
||||
@@ -525,6 +526,7 @@ class ActionHead(nn.Module):
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._temp = temp() if callable(temp) else temp
|
||||
|
||||
pre_layers = []
|
||||
@@ -591,7 +593,7 @@ class ActionHead(nn.Module):
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "onehot":
|
||||
x = self._dist_layer(x)
|
||||
dist = tools.OneHotDist(x)
|
||||
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
x = self._dist_layer(x)
|
||||
temp = self._temp
|
||||
|
||||
Reference in New Issue
Block a user