apply mean to log items for consistency

This commit is contained in:
NM512
2026-02-21 20:58:43 +09:00
parent 7433d1e877
commit be5e5ecf40

View File

@@ -147,12 +147,15 @@ class WorldModel(nn.Module):
model_loss = sum(scaled.values()) + kl_loss
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
# Store scalar metrics to avoid keeping (batch,time) arrays until the next log step.
metrics.update(
{f"{name}_loss": to_np(torch.mean(loss)) for name, loss in losses.items()}
)
metrics["kl_free"] = kl_free
metrics["dyn_scale"] = dyn_scale
metrics["rep_scale"] = rep_scale
metrics["dyn_loss"] = to_np(dyn_loss)
metrics["rep_loss"] = to_np(rep_loss)
metrics["dyn_loss"] = to_np(torch.mean(dyn_loss))
metrics["rep_loss"] = to_np(torch.mean(rep_loss))
metrics["kl"] = to_np(torch.mean(kl_value))
with torch.cuda.amp.autocast(self._use_amp):
metrics["prior_ent"] = to_np(