From ff02f41e73cce8b6ef9eef99c3830e131fbaf97f Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 8 Jan 2024 17:18:22 -0800 Subject: [PATCH] fix --- tdmpc2/tdmpc2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 0ae722f..c442f4e 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -257,16 +257,14 @@ class TDMPC2: terminated_pred = self.model.terminated(zs[-1], task) # Compute losses - reward_loss, terminated_loss, value_loss = 0, 0, 0 + reward_loss, value_loss = 0, 0 for t in range(self.cfg.horizon): reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t - terminated_loss += F.binary_cross_entropy(terminated_pred[t], terminated[t]) * self.cfg.rho**t for q in range(self.cfg.num_q): value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t terminated_loss = F.binary_cross_entropy(terminated_pred, terminated) consistency_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon) - terminated_loss *= (1/self.cfg.horizon) value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) total_loss = ( self.cfg.consistency_coef * consistency_loss +