added state input capability
This commit is contained in:
52
tools.py
52
tools.py
@@ -320,24 +320,34 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
||||
return sample
|
||||
|
||||
|
||||
class TwoHotDistSymlog:
|
||||
def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
|
||||
class DiscDist:
|
||||
def __init__(
|
||||
self,
|
||||
logits,
|
||||
low=-20.0,
|
||||
high=20.0,
|
||||
transfwd=symlog,
|
||||
transbwd=symexp,
|
||||
device="cuda",
|
||||
):
|
||||
self.logits = logits
|
||||
self.probs = torch.softmax(logits, -1)
|
||||
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
||||
self.width = (self.buckets[-1] - self.buckets[0]) / 255
|
||||
self.transfwd = transfwd
|
||||
self.transbwd = transbwd
|
||||
|
||||
def mean(self):
|
||||
_mean = self.probs * self.buckets
|
||||
return symexp(torch.sum(_mean, dim=-1, keepdim=True))
|
||||
return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))
|
||||
|
||||
def mode(self):
|
||||
_mode = self.probs * self.buckets
|
||||
return symexp(torch.sum(_mode, dim=-1, keepdim=True))
|
||||
return self.transbwd(torch.sum(_mode, dim=-1, keepdim=True))
|
||||
|
||||
# Inside OneHotCategorical, log_prob is calculated using only max element in targets
|
||||
def log_prob(self, x):
|
||||
x = symlog(x)
|
||||
x = self.transfwd(x)
|
||||
# x(time, batch, 1)
|
||||
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
|
||||
above = len(self.buckets) - torch.sum(
|
||||
@@ -366,15 +376,35 @@ class TwoHotDistSymlog:
|
||||
return (target * log_pred).sum(-1)
|
||||
|
||||
|
||||
class MSEDist:
|
||||
def __init__(self, mode, agg="sum"):
|
||||
self._mode = mode
|
||||
self._agg = agg
|
||||
|
||||
def mode(self):
|
||||
return self._mode
|
||||
|
||||
def mean(self):
|
||||
return self._mode
|
||||
|
||||
def log_prob(self, value):
|
||||
assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
|
||||
distance = (self._mode - value) ** 2
|
||||
if self._agg == "mean":
|
||||
loss = distance.mean(list(range(len(distance.shape)))[2:])
|
||||
elif self._agg == "sum":
|
||||
loss = distance.sum(list(range(len(distance.shape)))[2:])
|
||||
else:
|
||||
raise NotImplementedError(self._agg)
|
||||
return -loss
|
||||
|
||||
|
||||
class SymlogDist:
|
||||
def __init__(
|
||||
self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
|
||||
):
|
||||
def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
|
||||
self._mode = mode
|
||||
self._dist = dist
|
||||
self._agg = agg
|
||||
self._tol = tol
|
||||
self._dim_to_reduce = dim_to_reduce
|
||||
|
||||
def mode(self):
|
||||
return symexp(self._mode)
|
||||
@@ -393,9 +423,9 @@ class SymlogDist:
|
||||
else:
|
||||
raise NotImplementedError(self._dist)
|
||||
if self._agg == "mean":
|
||||
loss = distance.mean(self._dim_to_reduce)
|
||||
loss = distance.mean(list(range(len(distance.shape)))[2:])
|
||||
elif self._agg == "sum":
|
||||
loss = distance.sum(self._dim_to_reduce)
|
||||
loss = distance.sum(list(range(len(distance.shape)))[2:])
|
||||
else:
|
||||
raise NotImplementedError(self._agg)
|
||||
return -loss
|
||||
|
||||
Reference in New Issue
Block a user