modified loss calculation
This commit is contained in:
14
tools.py
14
tools.py
@@ -338,7 +338,7 @@ def sample_episodes(episodes, length, seed=0):
|
||||
if not ret:
|
||||
index = int(np_random.randint(0, total - 1))
|
||||
ret = {
|
||||
k: v[index : min(index + length, total)]
|
||||
k: v[index : min(index + length, total)].copy()
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
}
|
||||
@@ -350,7 +350,7 @@ def sample_episodes(episodes, length, seed=0):
|
||||
possible = length - size
|
||||
ret = {
|
||||
k: np.append(
|
||||
ret[k], v[index : min(index + possible, total)], axis=0
|
||||
ret[k], v[index : min(index + possible, total)].copy(), axis=0
|
||||
)
|
||||
for k, v in episode.items()
|
||||
if "log_" not in k
|
||||
@@ -482,6 +482,7 @@ class DiscDist:
|
||||
above = len(self.buckets) - torch.sum(
|
||||
(self.buckets > x[..., None]).to(torch.int32), dim=-1
|
||||
)
|
||||
# The original repo implements this with clip, since gradients are not backpropagated for out-of-range values.
|
||||
below = torch.clip(below, 0, len(self.buckets) - 1)
|
||||
above = torch.clip(above, 0, len(self.buckets) - 1)
|
||||
equal = below == above
|
||||
@@ -606,7 +607,7 @@ class Bernoulli:
|
||||
log_probs0 = -F.softplus(_logits)
|
||||
log_probs1 = -F.softplus(-_logits)
|
||||
|
||||
return log_probs0 * (1 - x) + log_probs1 * x
|
||||
return torch.sum(log_probs0 * (1 - x) + log_probs1 * x, -1)
|
||||
|
||||
|
||||
class UnnormalizedHuber(torchd.normal.Normal):
|
||||
@@ -739,11 +740,12 @@ class Optimizer:
|
||||
}[opt]()
|
||||
self._scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
|
||||
|
||||
def __call__(self, loss, params, retain_graph=False):
|
||||
def __call__(self, loss, params, retain_graph=True):
|
||||
assert len(loss.shape) == 0, loss.shape
|
||||
metrics = {}
|
||||
metrics[f"{self._name}_loss"] = loss.detach().cpu().numpy()
|
||||
self._scaler.scale(loss).backward()
|
||||
self._opt.zero_grad()
|
||||
self._scaler.scale(loss).backward(retain_graph=retain_graph)
|
||||
self._scaler.unscale_(self._opt)
|
||||
# loss.backward(retain_graph=retain_graph)
|
||||
norm = torch.nn.utils.clip_grad_norm_(params, self._clip)
|
||||
@@ -1001,11 +1003,9 @@ def recursively_collect_optim_state_dict(
|
||||
|
||||
|
||||
def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
    """Load state dicts into attributes of ``obj`` addressed by dotted paths.

    Args:
        obj: Root object whose (possibly nested) attributes are looked up.
        optimizers_state_dicts: Mapping from a dotted attribute path
            (for example ``"a.b"``) to the state dict handed to that
            attribute's ``load_state_dict`` method.

    Raises:
        AttributeError: If a path component is missing on the object
            reached so far (propagated from ``getattr``).
    """
    for path, state_dict in optimizers_state_dicts.items():
        target = obj
        # Walk the dotted path one attribute at a time, starting at the root.
        for key in path.split("."):
            target = getattr(target, key)
        target.load_state_dict(state_dict)
|
||||
|
||||
Reference in New Issue
Block a user