modified based on author's implementation
This commit is contained in:
156
tools.py
156
tools.py
@@ -17,6 +17,14 @@ from torch.utils.data import Dataset
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
def to_np(x):
    """Return tensor ``x`` as a NumPy array.

    The tensor is detached from the autograd graph and copied to host
    memory before conversion, so the result can be used for logging.
    """
    return x.detach().cpu().numpy()
|
||||
|
||||
def symlog(x):
    """Apply the symmetric log transform ``sign(x) * log(1 + |x|)`` elementwise.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # log1p keeps precision for |x| << 1, where log(|x| + 1.0) would round
    # the argument to exactly 1.0 and return 0.
    return torch.sign(x) * torch.log1p(torch.abs(x))
|
||||
|
||||
def symexp(x):
    """Apply ``sign(x) * (exp(|x|) - 1)`` elementwise (the inverse of symlog).

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # expm1 avoids the cancellation in exp(|x|) - 1.0 when |x| is tiny.
    return torch.sign(x) * torch.expm1(torch.abs(x))
|
||||
|
||||
class RequiresGrad:
|
||||
|
||||
def __init__(self, model):
|
||||
@@ -269,11 +277,13 @@ class SampleDist:
|
||||
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
||||
|
||||
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
|
||||
if logits is not None and probs is None and unimix_ratio > 0.0:
|
||||
if logits is not None and unimix_ratio > 0.0:
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
|
||||
logits = None
|
||||
super().__init__(logits=logits, probs=probs)
|
||||
logits = torch.log(probs)
|
||||
super().__init__(logits=logits, probs=None)
|
||||
else:
|
||||
super().__init__(logits=logits, probs=probs)
|
||||
|
||||
def mode(self):
|
||||
_mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
|
||||
@@ -290,42 +300,81 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
|
||||
return sample
|
||||
|
||||
|
||||
class TwoHotDistSymlog():
    """Categorical distribution over fixed buckets, scored with two-hot targets.

    The buckets are 255 evenly spaced points between ``low`` and ``high``.
    Targets passed to ``log_prob`` are mapped through ``symlog`` before being
    spread over the two neighbouring buckets; ``mean``/``mode`` map the
    expected bucket value back through ``symexp``.

    Args:
        logits: unnormalised bucket scores; the last dimension is expected to
            match the number of buckets (255).
        low: value of the first bucket.
        high: value of the last bucket.
        device: device the bucket tensor is moved to.
    """

    def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255).to(device)
        # NOTE(review): not read by any method of this class; it divides by
        # 255 although 255 buckets span 254 intervals -- confirm before use.
        self.width = (self.buckets[-1] - self.buckets[0]) / 255

    def mean(self):
        """Return the symexp of the probability-weighted bucket value."""
        return self.mode()

    def mode(self):
        """Return the symexp of the probability-weighted bucket value.

        The result keeps a trailing dimension of size 1.
        """
        expected = torch.sum(self.probs * self.buckets, dim=-1, keepdim=True)
        return symexp(expected)

    def log_prob(self, x):
        """Return the two-hot cross-entropy score of ``x`` under the logits.

        Args:
            x: raw (un-transformed) target values. The original code notes
                the shape as (time, batch, 1) -- TODO confirm with callers.

        Returns:
            Sum over buckets of ``target * log_softmax(logits)``.
        """
        x = symlog(x)
        num_buckets = len(self.buckets)
        # Index of the last bucket <= x and of the first bucket > x.
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = num_buckets - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
        below = torch.clip(below, 0, num_buckets - 1)
        above = torch.clip(above, 0, num_buckets - 1)
        # After clipping, both indices coincide when x was pushed onto an end
        # bucket; equal dummy distances then give that bucket the full weight.
        equal = (below == above)

        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        # The closer bucket receives the larger weight.
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=num_buckets) * weight_below[..., None] +
            F.one_hot(above, num_classes=num_buckets) * weight_above[..., None])
        # Drop the singleton dimension inherited from x's trailing axis.
        target = target.squeeze(-2)
        return self.log_prob_target(target)

    def log_prob_target(self, target):
        """Return ``sum(target * log_softmax(logits))`` over the bucket axis.

        Args:
            target: per-bucket target weights, broadcastable against the logits.
        """
        # This class has no base class, so the logits live on the instance;
        # super().logits would raise AttributeError here.
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)
|
||||
|
||||
class SymlogDist():
    """Point prediction in symlog space, scored by a negated distance.

    Args:
        mode: predicted value, already in symlog space.
        dist: distance type, ``'mse'`` (squared error) or ``'abs'``.
        agg: how the distance is reduced, ``'sum'`` or ``'mean'``.
        tol: distances below this threshold are treated as exactly zero.
        dim_to_reduce: dimensions reduced by ``agg`` in ``log_prob``.
    """

    # The default is a tuple so it is not a shared mutable list.
    def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=(-1, -2, -3)):
        self._mode = mode
        self._dist = dist
        self._agg = agg
        self._tol = tol
        self._dim_to_reduce = dim_to_reduce

    def mode(self):
        """Return the prediction mapped back through ``symexp``."""
        return symexp(self._mode)

    def mean(self):
        """Return the same value as ``mode``."""
        return self.mode()

    def log_prob(self, value):
        """Return the negated, aggregated distance to ``symlog(value)``.

        Args:
            value: raw target with exactly the same shape as the prediction.

        Raises:
            NotImplementedError: if ``dist`` or ``agg`` is not a supported name.
        """
        assert self._mode.shape == value.shape
        if self._dist == 'mse':
            distance = (self._mode - symlog(value)) ** 2.0
        elif self._dist == 'abs':
            distance = torch.abs(self._mode - symlog(value))
        else:
            raise NotImplementedError(self._dist)
        # Snap distances below the tolerance to exactly zero.
        distance = torch.where(distance < self._tol, 0, distance)

        if self._agg == 'mean':
            loss = distance.mean(self._dim_to_reduce)
        elif self._agg == 'sum':
            loss = distance.sum(self._dim_to_reduce)
        else:
            raise NotImplementedError(self._agg)
        return -loss
|
||||
|
||||
class ContDist:
|
||||
|
||||
@@ -438,6 +487,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
|
||||
indices = reversed(indices)
|
||||
flag = True
|
||||
for index in indices:
|
||||
# (inputs, pcont) -> (inputs[index], pcont[index])
|
||||
inp = lambda x: (_input[x] for _input in inputs)
|
||||
last = fn(last, *inp(index))
|
||||
if flag:
|
||||
@@ -446,6 +496,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
|
||||
else:
|
||||
outputs = torch.cat([outputs, last], dim=-1)
|
||||
outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
|
||||
outputs = torch.flip(outputs, [1])
|
||||
outputs = torch.unbind(outputs, dim=0)
|
||||
return outputs
|
||||
|
||||
@@ -687,14 +738,53 @@ def schedule(string, step):
|
||||
|
||||
def _trunc_normal_fan_avg_(weight, fan_in, fan_out):
    """Fill ``weight`` in place with a truncated normal scaled by average fan."""
    scale = 1.0 / ((fan_in + fan_out) / 2.0)
    # 0.8796... is the standard deviation of a unit normal truncated to
    # [-2, 2]; dividing by it restores the target std after truncation.
    std = float(np.sqrt(scale)) / 0.87962566103423978
    # trunc_normal_ takes absolute bounds, so express "two standard
    # deviations" explicitly; a fixed [-2, 2] would barely truncate at all
    # and leave the weights over-scaled by the correction factor.
    nn.init.trunc_normal_(weight, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)


def weight_init(m):
    """Initialise a module in place; intended for ``Module.apply``.

    * ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``: weights drawn
      from a normal truncated at two standard deviations, with variance
      ``1 / mean(fan_in, fan_out)``; biases set to zero.
    * ``nn.LayerNorm``: weight set to one and bias to zero, when present.

    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        _trunc_normal_fan_avg_(m.weight.data, m.in_features, m.out_features)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # Fan counts include the spatial extent of the kernel.
        space = m.kernel_size[0] * m.kernel_size[1]
        _trunc_normal_fan_avg_(
            m.weight.data, space * m.in_channels, space * m.out_channels)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.LayerNorm):
        # weight and bias are None when the layer has no affine parameters.
        if m.weight is not None:
            m.weight.data.fill_(1.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
|
||||
|
||||
def uniform_weight_init(given_scale):
    """Build an initialiser that draws ``nn.Linear`` weights uniformly.

    Args:
        given_scale: variance scale; the weights are drawn from
            ``U(-limit, limit)`` with
            ``limit = sqrt(3 * given_scale / mean(fan_in, fan_out))``.

    Returns:
        A function ``f(m)`` suitable for ``Module.apply``. It initialises
        ``nn.Linear`` (bias zeroed) and ``nn.LayerNorm`` (weight one, bias
        zero, when present) in place and leaves other modules untouched.
    """
    def f(m):
        if isinstance(m, nn.Linear):
            fan_avg = (m.in_features + m.out_features) / 2.0
            limit = np.sqrt(3 * given_scale / fan_avg)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
        elif isinstance(m, nn.LayerNorm):
            # weight and bias are None when the layer has no affine parameters.
            if m.weight is not None:
                m.weight.data.fill_(1.0)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
    return f
|
||||
|
||||
def tensorstats(tensor, prefix=None):
    """Summarise ``tensor`` as NumPy scalars for logging.

    Args:
        tensor: tensor to summarise.
        prefix: optional string; when truthy every key becomes
            ``'{prefix}_{key}'``.

    Returns:
        Dict with the keys ``mean``, ``std``, ``min`` and ``max`` (optionally
        prefixed), each holding the matching reduction over all elements
        converted with ``to_np``.
    """
    reducers = (
        ('mean', torch.mean),
        ('std', torch.std),
        ('min', torch.min),
        ('max', torch.max),
    )
    stats = {}
    for key, reduce_fn in reducers:
        stats[key] = to_np(reduce_fn(tensor))
    if not prefix:
        return stats
    return {f'{prefix}_{key}': value for key, value in stats.items()}
|
||||
|
||||
Reference in New Issue
Block a user