limit action values in sampling stage
This commit is contained in:
13
tools.py
13
tools.py
@@ -562,10 +562,11 @@ class SymlogDist:
|
||||
|
||||
|
||||
class ContDist:
|
||||
def __init__(self, dist=None):
|
||||
def __init__(self, dist=None, absmax=None):
|
||||
super().__init__()
|
||||
self._dist = dist
|
||||
self.mean = dist.mean
|
||||
self.absmax = absmax
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._dist, name)
|
||||
@@ -574,10 +575,16 @@ class ContDist:
|
||||
return self._dist.entropy()
|
||||
|
||||
def mode(self):
|
||||
return self._dist.mean
|
||||
out = self._dist.mean
|
||||
if self.absmax is not None:
|
||||
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
|
||||
return out
|
||||
|
||||
def sample(self, sample_shape=()):
|
||||
return self._dist.rsample(sample_shape)
|
||||
out = self._dist.rsample(sample_shape)
|
||||
if self.absmax is not None:
|
||||
out *= (self.absmax / torch.clip(torch.abs(out), min=self.absmax)).detach()
|
||||
return out
|
||||
|
||||
def log_prob(self, x):
|
||||
return self._dist.log_prob(x)
|
||||
|
||||
Reference in New Issue
Block a user