refactor pi outputs
This commit is contained in:
@@ -154,13 +154,12 @@ class WorldModel(nn.Module):
|
|||||||
action = mean + eps * log_std.exp()
|
action = mean + eps * log_std.exp()
|
||||||
mean, action, log_prob = math.squash(mean, action, log_prob)
|
mean, action, log_prob = math.squash(mean, action, log_prob)
|
||||||
|
|
||||||
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
|
|
||||||
info = TensorDict({
|
info = TensorDict({
|
||||||
"mean": mean,
|
"mean": mean,
|
||||||
"log_std": log_std,
|
"log_std": log_std,
|
||||||
"action_prob": 1.,
|
"action_prob": 1.,
|
||||||
"entropy": -log_prob,
|
"entropy": -log_prob,
|
||||||
"scaled_entropy": -log_prob * entropy_scale,
|
"scaled_entropy": -scaled_log_prob,
|
||||||
})
|
})
|
||||||
return action, info
|
return action, info
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user