refactor pi outputs

This commit is contained in:
Nicklas Hansen
2024-12-25 12:02:33 -08:00
parent 804f9b3949
commit db1865334e

View File

@@ -154,13 +154,12 @@ class WorldModel(nn.Module):
action = mean + eps * log_std.exp()
mean, action, log_prob = math.squash(mean, action, log_prob)
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
info = TensorDict({
"mean": mean,
"log_std": log_std,
"action_prob": 1.,
"entropy": -log_prob,
"scaled_entropy": -log_prob * entropy_scale,
"scaled_entropy": -scaled_log_prob,
})
return action, info