This commit is contained in:
Nicklas Hansen
2024-01-08 17:18:22 -08:00
parent e86c343a67
commit ff02f41e73

View File

@@ -257,16 +257,14 @@ class TDMPC2:
terminated_pred = self.model.terminated(zs[-1], task)
# Compute losses
reward_loss, terminated_loss, value_loss = 0, 0, 0
reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss += F.binary_cross_entropy(terminated_pred[t], terminated[t]) * self.cfg.rho**t
for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated)
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon)
terminated_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = (
self.cfg.consistency_coef * consistency_loss +