avoid ".to(device)"

This commit is contained in:
NM512
2024-09-28 07:58:15 +09:00
parent 669b7e1b43
commit 7433d1e877
5 changed files with 37 additions and 33 deletions

View File

@@ -461,7 +461,7 @@ class DiscDist:
):
self.logits = logits
self.probs = torch.softmax(logits, -1)
self.buckets = torch.linspace(low, high, steps=255).to(device)
self.buckets = torch.linspace(low, high, steps=255, device=device)
self.width = (self.buckets[-1] - self.buckets[0]) / 255
self.transfwd = transfwd
self.transbwd = transbwd
@@ -624,8 +624,7 @@ class UnnormalizedHuber(torchd.normal.Normal):
def log_prob(self, event):
return -(
torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
- self._threshold
torch.sqrt((event - self.mean) ** 2 + self._threshold**2) - self._threshold
)
def mode(self):
@@ -762,7 +761,7 @@ class Optimizer:
self._scaler.update()
# self._opt.step()
self._opt.zero_grad()
metrics[f"{self._name}_grad_norm"] = norm.item()
metrics[f"{self._name}_grad_norm"] = to_np(norm)
return metrics
def _apply_weight_decay(self, varibs):