avoid ".to(device)"
This commit is contained in:
7
tools.py
7
tools.py
@@ -461,7 +461,7 @@ class DiscDist:
|
||||
):
|
||||
self.logits = logits
|
||||
self.probs = torch.softmax(logits, -1)
|
||||
self.buckets = torch.linspace(low, high, steps=255).to(device)
|
||||
self.buckets = torch.linspace(low, high, steps=255, device=device)
|
||||
self.width = (self.buckets[-1] - self.buckets[0]) / 255
|
||||
self.transfwd = transfwd
|
||||
self.transbwd = transbwd
|
||||
@@ -624,8 +624,7 @@ class UnnormalizedHuber(torchd.normal.Normal):
|
||||
|
||||
def log_prob(self, event):
|
||||
return -(
|
||||
torch.sqrt((event - self.mean) ** 2 + self._threshold**2)
|
||||
- self._threshold
|
||||
torch.sqrt((event - self.mean) ** 2 + self._threshold**2) - self._threshold
|
||||
)
|
||||
|
||||
def mode(self):
|
||||
@@ -762,7 +761,7 @@ class Optimizer:
|
||||
self._scaler.update()
|
||||
# self._opt.step()
|
||||
self._opt.zero_grad()
|
||||
metrics[f"{self._name}_grad_norm"] = norm.item()
|
||||
metrics[f"{self._name}_grad_norm"] = to_np(norm)
|
||||
return metrics
|
||||
|
||||
def _apply_weight_decay(self, varibs):
|
||||
|
||||
Reference in New Issue
Block a user