factor pi outputs

This commit is contained in:
Nicklas Hansen
2024-12-25 12:08:07 -08:00
parent db1865334e
commit e452ca7539

View File

@@ -154,12 +154,13 @@ class WorldModel(nn.Module):
action = mean + eps * log_std.exp()
mean, action, log_prob = math.squash(mean, action, log_prob)
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
info = TensorDict({
"mean": mean,
"log_std": log_std,
"action_prob": 1.,
"entropy": -log_prob,
"scaled_entropy": -scaled_log_prob,
"scaled_entropy": -log_prob * entropy_scale,
})
return action, info