limit action values in sampling stage

This commit is contained in:
NM512
2024-01-05 11:42:45 +09:00
parent a9e85e8b7c
commit a27711ab96
3 changed files with 33 additions and 15 deletions

View File

@@ -562,10 +562,11 @@ class SymlogDist:
class ContDist:
def __init__(self, dist=None):
def __init__(self, dist=None, absmax=None):
super().__init__()
self._dist = dist
self.mean = dist.mean
self.absmax = absmax
def __getattr__(self, name):
return getattr(self._dist, name)
@@ -574,10 +575,16 @@ class ContDist:
return self._dist.entropy()
def mode(self):
return self._dist.mean
out = self._dist.mean
if self.absmax is not None:
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
return out
def sample(self, sample_shape=()):
return self._dist.rsample(sample_shape)
out = self._dist.rsample(sample_shape)
if self.absmax is not None:
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
return out
def log_prob(self, x):
return self._dist.log_prob(x)