limit action values in sampling stage

This commit is contained in:
NM512
2024-01-05 11:42:45 +09:00
parent a9e85e8b7c
commit a27711ab96
3 changed files with 33 additions and 15 deletions

View File

@@ -200,9 +200,8 @@ class RSSM(nn.Module):
return dist
def obs_step(self, prev_state, prev_action, embed, is_first, sample=True):
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _ims_stat_layer)
# if shared is True, prior and post both use same networks(inp_layers, _img_out_layers, _imgs_stat_layer)
# otherwise, post use different network(_obs_out_layers) with prior[deter] and embed as inputs
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
# initialize all prev_state
if prev_state == None or torch.sum(is_first) == len(is_first):
@@ -246,7 +245,6 @@ class RSSM(nn.Module):
# this is used for making future image
def img_step(self, prev_state, prev_action, embed=None, sample=True):
# (batch, stoch, discrete_num)
prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
prev_stoch = prev_state["stoch"]
if self._discrete:
shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
@@ -644,6 +642,7 @@ class MLP(nn.Module):
std=1.0,
min_std=0.1,
max_std=1.0,
absmax=None,
temp=0.1,
unimix_ratio=0.01,
outscale=1.0,
@@ -660,12 +659,13 @@ class MLP(nn.Module):
norm = getattr(torch.nn, norm)
self._dist = dist
self._std = std
self._symlog_inputs = symlog_inputs
self._device = device
self._min_std = min_std
self._max_std = max_std
self._absmax = absmax
self._temp = temp
self._unimix_ratio = unimix_ratio
self._symlog_inputs = symlog_inputs
self._device = device
self.layers = nn.Sequential()
for index in range(self._layers):
@@ -738,23 +738,33 @@ class MLP(nn.Module):
std + 2.0
) + self._min_std
dist = torchd.normal.Normal(torch.tanh(mean), std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1), absmax=1.0)
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "normal_std_fixed":
dist = torchd.normal.Normal(mean, self._std)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "trunc_normal":
mean = torch.tanh(mean)
std = 2 * torch.sigmoid(std / 2) + self._min_std
dist = tools.SafeTruncatedNormal(mean, std, -1, 1)
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
dist = tools.ContDist(
torchd.independent.Independent(dist, 1), absmax=self._absmax
)
elif self._dist == "onehot":
dist = tools.OneHotDist(mean, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
dist = tools.ContDist(torchd.gumbel.Gumbel(mean, 1 / self._temp))
dist = tools.ContDist(
torchd.gumbel.Gumbel(mean, 1 / self._temp), absmax=self._absmax
)
elif dist == "huber":
dist = tools.ContDist(
torchd.independent.Independent(
tools.UnnormalizedHuber(mean, std, 1.0), len(shape)
tools.UnnormalizedHuber(mean, std, 1.0),
len(shape),
absmax=self._absmax,
)
)
elif dist == "binary":