fix
This commit is contained in:
@@ -257,16 +257,14 @@ class TDMPC2:
|
||||
terminated_pred = self.model.terminated(zs[-1], task)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, terminated_loss, value_loss = 0, 0, 0
|
||||
reward_loss, value_loss = 0, 0
|
||||
for t in range(self.cfg.horizon):
|
||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
||||
terminated_loss += F.binary_cross_entropy(terminated_pred[t], terminated[t]) * self.cfg.rho**t
|
||||
for q in range(self.cfg.num_q):
|
||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
||||
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
|
||||
consistency_loss *= (1/self.cfg.horizon)
|
||||
reward_loss *= (1/self.cfg.horizon)
|
||||
terminated_loss *= (1/self.cfg.horizon)
|
||||
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
|
||||
Reference in New Issue
Block a user