From e452ca7539da05d2b4628ace953a048b46bb501b Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 25 Dec 2024 12:08:07 -0800 Subject: [PATCH] factor pi outputs --- tdmpc2/common/world_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index b221364..4babde7 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -154,12 +154,13 @@ class WorldModel(nn.Module): action = mean + eps * log_std.exp() mean, action, log_prob = math.squash(mean, action, log_prob) + entropy_scale = scaled_log_prob / (log_prob + 1e-8) info = TensorDict({ "mean": mean, "log_std": log_std, "action_prob": 1., "entropy": -log_prob, - "scaled_entropy": -scaled_log_prob, + "scaled_entropy": -log_prob * entropy_scale, }) return action, info