diff --git a/models.py b/models.py index 5d27ff1..ae923e8 100644 --- a/models.py +++ b/models.py @@ -147,12 +147,15 @@ class WorldModel(nn.Module): model_loss = sum(scaled.values()) + kl_loss metrics = self._model_opt(torch.mean(model_loss), self.parameters()) - metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()}) + # Store scalar metrics to avoid keeping (batch,time) arrays until the next log step. + metrics.update( + {f"{name}_loss": to_np(torch.mean(loss)) for name, loss in losses.items()} + ) metrics["kl_free"] = kl_free metrics["dyn_scale"] = dyn_scale metrics["rep_scale"] = rep_scale - metrics["dyn_loss"] = to_np(dyn_loss) - metrics["rep_loss"] = to_np(rep_loss) + metrics["dyn_loss"] = to_np(torch.mean(dyn_loss)) + metrics["rep_loss"] = to_np(torch.mean(rep_loss)) metrics["kl"] = to_np(torch.mean(kl_value)) with torch.cuda.amp.autocast(self._use_amp): metrics["prior_ent"] = to_np(