diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 4babde7..b221364 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -154,13 +154,12 @@ class WorldModel(nn.Module): action = mean + eps * log_std.exp() mean, action, log_prob = math.squash(mean, action, log_prob) - entropy_scale = scaled_log_prob / (log_prob + 1e-8) info = TensorDict({ "mean": mean, "log_std": log_std, "action_prob": 1., "entropy": -log_prob, - "scaled_entropy": -log_prob * entropy_scale, + "scaled_entropy": -scaled_log_prob, }) return action, info