disable uncertainty estimation when coef=0

This commit is contained in:
Nicklas Hansen
2024-01-04 19:39:44 -08:00
parent 392b16ac89
commit 188bd201aa

View File

@@ -94,6 +94,8 @@ class TDMPC2:
@torch.no_grad() @torch.no_grad()
def _estimate_uncertainty(self, z, task): def _estimate_uncertainty(self, z, task):
"""Estimates epistemic uncertainty, normalized by predicted value.""" """Estimates epistemic uncertainty, normalized by predicted value."""
if self.cfg.uncertainty_coef == 0:
return 0
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg) qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef