apply mean to log items for consistency
This commit is contained in:
@@ -147,12 +147,15 @@ class WorldModel(nn.Module):
|
|||||||
model_loss = sum(scaled.values()) + kl_loss
|
model_loss = sum(scaled.values()) + kl_loss
|
||||||
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
metrics = self._model_opt(torch.mean(model_loss), self.parameters())
|
||||||
|
|
||||||
metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
|
# Store scalar metrics to avoid keeping (batch,time) arrays until the next log step.
|
||||||
|
metrics.update(
|
||||||
|
{f"{name}_loss": to_np(torch.mean(loss)) for name, loss in losses.items()}
|
||||||
|
)
|
||||||
metrics["kl_free"] = kl_free
|
metrics["kl_free"] = kl_free
|
||||||
metrics["dyn_scale"] = dyn_scale
|
metrics["dyn_scale"] = dyn_scale
|
||||||
metrics["rep_scale"] = rep_scale
|
metrics["rep_scale"] = rep_scale
|
||||||
metrics["dyn_loss"] = to_np(dyn_loss)
|
metrics["dyn_loss"] = to_np(torch.mean(dyn_loss))
|
||||||
metrics["rep_loss"] = to_np(rep_loss)
|
metrics["rep_loss"] = to_np(torch.mean(rep_loss))
|
||||||
metrics["kl"] = to_np(torch.mean(kl_value))
|
metrics["kl"] = to_np(torch.mean(kl_value))
|
||||||
with torch.cuda.amp.autocast(self._use_amp):
|
with torch.cuda.amp.autocast(self._use_amp):
|
||||||
metrics["prior_ent"] = to_np(
|
metrics["prior_ent"] = to_np(
|
||||||
|
|||||||
Reference in New Issue
Block a user