modified based on author's implementation

This commit is contained in:
NM512
2023-03-18 08:38:23 +09:00
parent a678a509b9
commit 6273444394
6 changed files with 371 additions and 229 deletions

156
tools.py
View File

@@ -17,6 +17,14 @@ from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
def to_np(x):
    """Return the contents of tensor ``x`` as a NumPy array.

    The tensor is detached from the autograd graph and copied to the CPU
    before conversion, so it may require grad and/or live on another device.
    """
    return x.detach().cpu().numpy()
def symlog(x):
    """Apply the symmetric log transform ``sign(x) * log(1 + |x|)`` elementwise.

    ``torch.log1p`` is used instead of ``torch.log(|x| + 1.0)`` so that inputs
    much smaller than 1 do not get rounded away when 1 is added to them.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))
def symexp(x):
    """Apply ``sign(x) * (exp(|x|) - 1)`` elementwise (the inverse of symlog).

    ``torch.expm1`` is used instead of ``torch.exp(|x|) - 1.0`` so that inputs
    close to zero do not lose all precision to cancellation.
    """
    return torch.sign(x) * torch.expm1(torch.abs(x))
class RequiresGrad:
def __init__(self, model):
@@ -269,11 +277,13 @@ class SampleDist:
class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
if logits is not None and probs is None and unimix_ratio > 0.0:
if logits is not None and unimix_ratio > 0.0:
probs = F.softmax(logits, dim=-1)
probs = probs * (1.0-unimix_ratio) + unimix_ratio / probs.shape[-1]
logits = None
super().__init__(logits=logits, probs=probs)
logits = torch.log(probs)
super().__init__(logits=logits, probs=None)
else:
super().__init__(logits=logits, probs=probs)
def mode(self):
_mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
@@ -290,42 +300,81 @@ class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
return sample
class TwoHotDistSymlog():
    """Categorical distribution over fixed buckets, read out in symlog space.

    The logits score 255 buckets linearly spaced between ``low`` and ``high``.
    ``mean``/``mode`` return the probability-weighted bucket value passed
    through ``symexp``; ``log_prob`` scores a scalar target by spreading it
    over its two neighbouring buckets (a "two-hot" target).
    """

    def __init__(self, logits=None, low=-20.0, high=20.0, device='cuda'):
        """Store the logits and build the bucket grid.

        Args:
            logits: tensor whose last dimension indexes the 255 buckets.
            low: value of the first bucket.
            high: value of the last bucket.
            device: device the bucket tensor is moved to.
        """
        self.logits = logits
        self.probs = torch.softmax(logits, -1)
        self.buckets = torch.linspace(low, high, steps=255).to(device)
        # Not read by any method of this class; kept as a public attribute.
        # NOTE(review): 255 grid points span 254 intervals, so this is not the
        # exact bucket spacing -- confirm intent before relying on it.
        self.width = (self.buckets[-1] - self.buckets[0]) / 255

    def mean(self):
        """Return symexp of the expected bucket value, keeping a trailing dim of 1."""
        expected = self.probs * self.buckets
        return symexp(torch.sum(expected, dim=-1, keepdim=True))

    def mode(self):
        """Return the same value as ``mean`` (the expectation, not an argmax)."""
        return self.mean()

    def log_prob(self, x):
        """Return the cross-entropy style log-probability of ``x``.

        ``x`` is transformed with ``symlog`` and turned into a two-hot target
        over the buckets. The ``squeeze(-2)`` below implies ``x`` carries a
        trailing singleton dimension, e.g. (time, batch, 1).
        """
        x = symlog(x)
        n_buckets = len(self.buckets)
        # Index of the last bucket <= x and of the first bucket > x.
        below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
        above = n_buckets - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
        below = torch.clip(below, 0, n_buckets - 1)
        above = torch.clip(above, 0, n_buckets - 1)
        # After clipping the indices coincide only when x is outside the grid;
        # equal dummy distances then put the whole weight on that edge bucket.
        equal = (below == above)

        dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
        dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
        total = dist_to_below + dist_to_above
        # The closer bucket receives the larger weight.
        weight_below = dist_to_above / total
        weight_above = dist_to_below / total
        target = (
            F.one_hot(below, num_classes=n_buckets) * weight_below[..., None] +
            F.one_hot(above, num_classes=n_buckets) * weight_above[..., None])
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        target = target.squeeze(-2)

        return (target * log_pred).sum(-1)

    def log_prob_target(self, target):
        """Return ``sum(target * log_softmax(logits))`` over the bucket dimension.

        ``target`` is an already-built distribution over the buckets.
        """
        # The class has no base class, so the logits must be read from self;
        # super().logits would raise AttributeError.
        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
        return (target * log_pred).sum(-1)
class SymlogDist():
    """Point prediction stored in symlog space with a distance-based log_prob.

    Args:
        mode: predicted value, already in symlog space.
        dist: distance used by ``log_prob``: 'mse' (squared) or 'abs'.
        agg: how the distance is reduced: 'sum' or 'mean'.
        tol: distances below this threshold are set to zero.
        dim_to_reduce: dimensions over which the distance is aggregated.
    """

    # The default is a tuple rather than a list so the default object cannot
    # be mutated and shared between instances.
    def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=(-1, -2, -3)):
        self._mode = mode
        self._dist = dist
        self._agg = agg
        self._tol = tol
        self._dim_to_reduce = dim_to_reduce

    def mode(self):
        """Return the prediction mapped back through ``symexp``."""
        return symexp(self._mode)

    def mean(self):
        """Return the same value as ``mode``."""
        return symexp(self._mode)

    def log_prob(self, value):
        """Return the negated aggregated distance between the mode and symlog(value).

        Raises:
            AssertionError: if ``value`` does not have the shape of the mode.
            NotImplementedError: for an unknown ``dist`` or ``agg`` setting.
        """
        assert self._mode.shape == value.shape
        if self._dist == 'mse':
            distance = (self._mode - symlog(value)) ** 2.0
            distance = torch.where(distance < self._tol, 0, distance)
        elif self._dist == 'abs':
            distance = torch.abs(self._mode - symlog(value))
            distance = torch.where(distance < self._tol, 0, distance)
        else:
            raise NotImplementedError(self._dist)
        if self._agg == 'mean':
            loss = distance.mean(self._dim_to_reduce)
        elif self._agg == 'sum':
            loss = distance.sum(self._dim_to_reduce)
        else:
            raise NotImplementedError(self._agg)
        return -loss
class ContDist:
@@ -438,6 +487,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
indices = reversed(indices)
flag = True
for index in indices:
# (inputs, pcont) -> (inputs[index], pcont[index])
inp = lambda x: (_input[x] for _input in inputs)
last = fn(last, *inp(index))
if flag:
@@ -446,6 +496,7 @@ def static_scan_for_lambda_return(fn, inputs, start):
else:
outputs = torch.cat([outputs, last], dim=-1)
outputs = torch.reshape(outputs, [outputs.shape[0], outputs.shape[1], 1])
outputs = torch.flip(outputs, [1])
outputs = torch.unbind(outputs, dim=0)
return outputs
@@ -687,14 +738,53 @@ def schedule(string, step):
def weight_init(m):
    """Initialise module ``m`` in place; intended for use with ``Module.apply``.

    Linear, Conv2d and ConvTranspose2d weights are drawn from a normal
    distribution truncated at two standard deviations and scaled by the
    average of fan-in and fan-out; their biases are zeroed. LayerNorm weights
    are set to one and biases to zero. Other module types are left untouched.
    """
    # 0.8796... is the standard deviation of a unit normal truncated to
    # [-2, 2]; dividing by it makes the truncated samples end up with a
    # standard deviation of sqrt(scale).
    if isinstance(m, nn.Linear):
        in_num = m.in_features
        out_num = m.out_features
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = float(np.sqrt(scale) / 0.87962566103423978)
        # trunc_normal_ takes absolute cut-off values, so the bounds must be
        # expressed in units of std for the correction above to hold.
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        space = m.kernel_size[0] * m.kernel_size[1]
        in_num = space * m.in_channels
        out_num = space * m.out_channels
        denoms = (in_num + out_num) / 2.0
        scale = 1.0 / denoms
        std = float(np.sqrt(scale) / 0.87962566103423978)
        nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0 * std, b=2.0 * std)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, nn.LayerNorm):
        # Weight and bias may be None when the layer has no affine parameters.
        if hasattr(m.weight, 'data'):
            m.weight.data.fill_(1.0)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
def uniform_weight_init(given_scale):
    """Build an in-place initialiser that draws Linear weights uniformly.

    Args:
        given_scale: multiplier applied to ``1 / mean(fan_in, fan_out)`` to
            obtain the target variance of the weights.

    Returns:
        A function taking a module. For ``nn.Linear`` it fills the weight from
        ``U(-limit, limit)`` with ``limit = sqrt(3 * scale)`` and zeroes the
        bias; for ``nn.LayerNorm`` it sets the weight to one and the bias to
        zero. Other module types are left untouched.
    """
    def f(m):
        if isinstance(m, nn.Linear):
            in_num = m.in_features
            out_num = m.out_features
            denoms = (in_num + out_num) / 2.0
            scale = given_scale / denoms
            # A uniform distribution on [-limit, limit] has variance
            # limit**2 / 3, so this limit yields a variance of `scale`.
            limit = np.sqrt(3 * scale)
            nn.init.uniform_(m.weight.data, a=-limit, b=limit)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)
        elif isinstance(m, nn.LayerNorm):
            m.weight.data.fill_(1.0)
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)
    return f
def tensorstats(tensor, prefix=None):
    """Summarise ``tensor`` as a dict of NumPy scalars.

    Args:
        tensor: tensor to summarise over all of its elements.
        prefix: optional string; when truthy every key becomes
            ``f'{prefix}_{key}'``.

    Returns:
        Dict with the keys 'mean', 'std', 'min' and 'max' (optionally
        prefixed), each converted with ``to_np``.
    """
    metrics = {
        'mean': to_np(torch.mean(tensor)),
        'std': to_np(torch.std(tensor)),
        'min': to_np(torch.min(tensor)),
        'max': to_np(torch.max(tensor)),
    }
    if prefix:
        metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
    return metrics