@torch.no_grad()
def _estimate_uncertainty(self, z, task):
    """Estimate epistemic uncertainty of the Q predictions for latent ``z``.

    Method of ``TDMPC2`` (tdmpc2/tdmpc2.py). Runs with gradient tracking
    disabled.

    Args:
        z: Latent state(s) handed to the policy and the Q-functions.
        task: Task identifier forwarded unchanged to ``self.model.pi``,
            ``self.model.Q`` (presumably only relevant for multi-task
            models -- confirm against the world model).

    Returns:
        The Python int ``0`` when ``cfg.uncertainty_coef`` is zero; in that
        case neither the policy nor the Q-functions are evaluated.
        Otherwise ``mean(q) * std(q, dim=0) * cfg.uncertainty_coef``, where
        ``q`` is the ``return_type='all'`` Q output after
        ``math.two_hot_inv``.
    """
    coef = self.cfg.uncertainty_coef
    # Fast path: a zero coefficient would null the estimate anyway, so
    # skip the policy and Q forward passes entirely.
    if coef == 0:
        return 0

    # Element [1] of the policy output is used as the action to score.
    action = self.model.pi(z, task)[1]
    q_out = self.model.Q(z, action, task, return_type='all')
    q_values = math.two_hot_inv(q_out, self.cfg)

    # Spread along dim 0 (presumably the Q-ensemble axis -- TODO confirm),
    # scaled by the overall mean predicted value and the coefficient.
    return q_values.mean() * q_values.std(0) * coef