apply mean to log items for consistency
This commit is contained in:
@@ -147,12 +147,15 @@ class WorldModel(nn.Module):
|
||||
model_loss = sum(scaled.values()) + kl_loss
|
||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||
|
||||
metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
|
||||
# Store scalar metrics to avoid keeping (batch,time) arrays until the next log step.
|
||||
metrics.update(
|
||||
{f"{name}_loss": to_np(torch.mean(loss)) for name, loss in losses.items()}
|
||||
)
|
||||
metrics["kl_free"] = kl_free
|
||||
metrics["dyn_scale"] = dyn_scale
|
||||
metrics["rep_scale"] = rep_scale
|
||||
metrics["dyn_loss"] = to_np(dyn_loss)
|
||||
metrics["rep_loss"] = to_np(rep_loss)
|
||||
metrics["dyn_loss"] = to_np(torch.mean(dyn_loss))
|
||||
metrics["rep_loss"] = to_np(torch.mean(rep_loss))
|
||||
metrics["kl"] = to_np(torch.mean(kl_value))
|
||||
with torch.cuda.amp.autocast(self._use_amp):
|
||||
metrics["prior_ent"] = to_np(
|
||||
|
||||
Reference in New Issue
Block a user