limit action values in sampling stage
This commit is contained in:
30
networks.py
30
networks.py
@@ -200,9 +200,8 @@ class RSSM(nn.Module):
|
||||
return dist
|
||||
|
||||
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
|
||||
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
|
||||
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
|
||||
# initialize all prev_state
|
||||
if prev_state == None or torch.sum(is_first) == len(is_first):
|
||||
@@ -246,7 +245,6 @@ class RSSM(nn.Module):
|
||||
# this is used for making future image
|
||||
def img_step(self, prev_state, prev_action, embed=None, sample=True):
|
||||
# (batch, stoch, discrete_num)
|
||||
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
|
||||
prev_stoch = prev_state["stoch"]
|
||||
if self._discrete:
|
||||
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
|
||||
@@ -644,6 +642,7 @@ class MLP(nn.Module):
|
||||
std=1.0,
|
||||
min_std=0.1,
|
||||
max_std=1.0,
|
||||
absmax=None,
|
||||
temp=0.1,
|
||||
unimix_ratio=0.01,
|
||||
outscale=1.0,
|
||||
@@ -660,12 +659,13 @@ class MLP(nn.Module):
|
||||
norm = getattr(torch.nn, norm)
|
||||
self._dist = dist
|
||||
self._std = std
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._absmax = absmax
|
||||
self._temp = temp
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._symlog_inputs = symlog_inputs
|
||||
self._device = device
|
||||
|
||||
self.layers = nn.Sequential()
|
||||
for index in range(self._layers):
|
||||
@@ -738,23 +738,33 @@ class MLP(nn.Module):
|
||||
std + 2.0
|
||||
) + self._min_std
|
||||
dist = torchd.normal.Normal(torch.tanh(mean), std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "normal_std_fixed":
|
||||
dist = torchd.normal.Normal(mean, self._std)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "trunc_normal":
|
||||
mean = torch.tanh(mean)
|
||||
std = 2 * torch.sigmoid(std / 2) + self._min_std
|
||||
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(dist, 1), absmax=self._absmax
|
||||
)
|
||||
elif self._dist == "onehot":
|
||||
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
|
||||
dist = tools.ContDist(
|
||||
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
|
||||
)
|
||||
elif dist == "huber":
|
||||
dist = tools.ContDist(
|
||||
torchd.independent.Independent(
|
||||
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
|
||||
tools.UnnormalizedHuber(mean, std, 1.0),
|
||||
len(shape),
|
||||
absmax=self._absmax,
|
||||
)
|
||||
)
|
||||
elif dist == "binary":
|
||||
|
||||
Reference in New Issue
Block a user