diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index b221364..4babde7 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -154,12 +154,13 @@ class WorldModel(nn.Module): action = mean + eps * log_std.exp() mean, action, log_prob = math.squash(mean, action, log_prob) + entropy_scale = scaled_log_prob / (log_prob + 1e-8) info = TensorDict({ "mean": mean, "log_std": log_std, "action_prob": 1., "entropy": -log_prob, - "scaled_entropy": -scaled_log_prob, + "scaled_entropy": -log_prob * entropy_scale, }) return action, info