@torch.no_grad()
def _estimate_uncertainty(self, z, task):
    """Return an uncertainty penalty derived from disagreement between Q-estimates.

    The penalty is the standard deviation of the Q-estimates along dim 0,
    scaled by the mean over all Q-estimates and by ``cfg.uncertainty_coef``.

    Args:
        z: Latent state tensor passed to the policy and the Q-functions.
        task: Task identifier forwarded unchanged to ``model.pi`` and ``model.Q``.

    Returns:
        A ``torch.Tensor``. When ``cfg.uncertainty_coef`` is 0 this is a 0-dim
        zero tensor on the device/dtype of ``z``, so every code path yields the
        same type and callers may use tensor methods on the result.
    """
    coef = self.cfg.uncertainty_coef
    if coef == 0:
        # Fast path: skip the policy and Q forward passes when the penalty is
        # disabled. A 0-dim tensor broadcasts exactly like the scalar 0 in
        # arithmetic but keeps the return type consistent with the slow path.
        return z.new_zeros(())

    # Element 1 of the policy output is what the Q-functions are evaluated on
    # (presumably the sampled action — confirm against model.pi).
    policy_out = self.model.pi(z, task)[1]
    q_logits = self.model.Q(z, policy_out, task, return_type='all')
    qs = math.two_hot_inv(q_logits, self.cfg)

    # Assumes dim 0 of `qs` indexes the individual Q-functions — TODO confirm.
    # NOTE(review): torch.std uses the unbiased estimator by default, so a
    # size-1 dim 0 would produce NaN here; verify there are always >= 2.
    disagreement = qs.std(0)
    return qs.mean() * disagreement * coef