added state input capability

This commit is contained in:
NM512
2023-05-14 23:38:46 +09:00
parent 3ebb8ad617
commit b984e69b6e
8 changed files with 369 additions and 142 deletions

View File

@@ -320,24 +320,34 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
return sample
class TwoHotDistSymlog:
def __init__(self, logits=None, low=-20.0, high=20.0, device="cuda"):
class DiscDist:
def __init__(
    self,
    logits,
    low=-20.0,
    high=20.0,
    transfwd=symlog,
    transbwd=symexp,
    device="cuda",
):
    """Set up a discretised distribution over a fixed grid of buckets.

    Args:
        logits: Unnormalised scores; a softmax over the last axis turns
            them into bucket probabilities.
        low: Location of the first bucket.
        high: Location of the last bucket.
        transfwd: Transform stored for mapping values into bucket space.
        transbwd: Transform stored for mapping bucket space back to values.
        device: Device the bucket grid is moved to.
    """
    # NOTE(review): the grid size is fixed, so the last axis of `logits`
    # presumably has to hold 255 entries -- confirm against callers.
    num_buckets = 255

    self.logits = logits
    self.probs = torch.softmax(logits, -1)

    grid = torch.linspace(low, high, steps=num_buckets).to(device)
    self.buckets = grid
    # Full span of the grid divided by the number of buckets.
    self.width = (grid[-1] - grid[0]) / num_buckets

    self.transfwd = transfwd
    self.transbwd = transbwd
def mean(self):
    """Return the probability-weighted bucket average, mapped by ``transbwd``.

    The bucket locations are weighted by ``self.probs`` and summed over the
    last axis, which is kept as a size-1 axis. The result is then passed
    through ``self.transbwd``.
    """
    expectation = torch.sum(self.probs * self.buckets, dim=-1, keepdim=True)
    # Apply the configured inverse transform; a hard-coded symexp here would
    # silently ignore a custom `transbwd` stored on the instance.
    return self.transbwd(expectation)
def mode(self):
    """Return the point estimate of the distribution, mapped by ``transbwd``.

    This computes the same probability-weighted bucket average as ``mean``
    (not an arg-max over buckets): the weighted bucket locations are summed
    over the last axis, kept as a size-1 axis, and passed through
    ``self.transbwd``.
    """
    weighted = torch.sum(self.probs * self.buckets, dim=-1, keepdim=True)
    # Apply the configured inverse transform; a hard-coded symexp here would
    # silently ignore a custom `transbwd` stored on the instance.
    return self.transbwd(weighted)
# Inside OneHotCategorical, log_prob is calculated using only max element in targets
def log_prob(self, x):
x = symlog(x)
x = self.transfwd(x)
# x(time, batch, 1)
below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
above = len(self.buckets) - torch.sum(
@@ -366,15 +376,35 @@ class TwoHotDistSymlog:
return (target * log_pred).sum(-1)
class MSEDist:
    """Distribution-like wrapper whose log-probability is a negative squared error.

    Args:
        mode: Predicted value; returned unchanged by ``mode()`` and ``mean()``.
        agg: How the squared error is aggregated over every axis from index 2
            onwards: ``"sum"`` or ``"mean"``.
    """

    def __init__(self, mode, agg="sum"):
        self._mode = mode
        self._agg = agg

    def mode(self):
        """Return the stored prediction."""
        return self._mode

    def mean(self):
        """Return the stored prediction (identical to ``mode``)."""
        return self._mode

    def log_prob(self, value):
        """Return the negative aggregated squared error against ``value``.

        The first two axes are preserved; all remaining axes are reduced
        according to ``agg``.

        Raises:
            AssertionError: If ``value`` does not have the prediction's shape.
            NotImplementedError: If ``agg`` is neither ``"mean"`` nor ``"sum"``.
        """
        assert self._mode.shape == value.shape, (self._mode.shape, value.shape)
        if self._agg not in ("mean", "sum"):
            raise NotImplementedError(self._agg)

        distance = (self._mode - value) ** 2
        reduce_dims = list(range(2, len(distance.shape)))
        if not reduce_dims:
            # Nothing to aggregate. Passing an empty dim list to a PyTorch
            # reduction can reduce over *every* axis, which would collapse
            # the leading axes into a scalar, so return per-element values.
            return -distance

        if self._agg == "mean":
            loss = distance.mean(reduce_dims)
        else:
            loss = distance.sum(reduce_dims)
        return -loss
class SymlogDist:
def __init__(
self, mode, dist="mse", agg="sum", tol=1e-8, dim_to_reduce=[-1, -2, -3]
):
def __init__(self, mode, dist="mse", agg="sum", tol=1e-8):
self._mode = mode
self._dist = dist
self._agg = agg
self._tol = tol
self._dim_to_reduce = dim_to_reduce
def mode(self):
return symexp(self._mode)
@@ -393,9 +423,9 @@ class SymlogDist:
else:
raise NotImplementedError(self._dist)
if self._agg == "mean":
loss = distance.mean(self._dim_to_reduce)
loss = distance.mean(list(range(len(distance.shape)))[2:])
elif self._agg == "sum":
loss = distance.sum(self._dim_to_reduce)
loss = distance.sum(list(range(len(distance.shape)))[2:])
else:
raise NotImplementedError(self._agg)
return -loss