disable uncertainty estimation when coef=0

This commit is contained in:
Nicklas Hansen
2024-01-04 19:39:44 -08:00
parent 194c92331c
commit e5c9029c86

View File

@@ -91,6 +91,8 @@ class TDMPC2:
@torch.no_grad()
def _estimate_uncertainty(self, z, task):
    """
    Compute an uncertainty term from the spread of the Q outputs at `z`.

    The result is the mean over all decoded Q-values, multiplied by their
    standard deviation along dim 0, multiplied by `cfg.uncertainty_coef`.
    When the coefficient equals 0 the model is never queried and the
    plain integer 0 is returned.

    NOTE(review): the previous docstring described this as epistemic
    uncertainty "normalized by predicted value"; the code multiplies by
    the mean rather than dividing by it -- confirm which is intended.
    """
    coef = self.cfg.uncertainty_coef
    if coef == 0:
        # Estimation disabled: skip the policy and Q forward passes.
        return 0

    # Second element of the policy output (presumably the sampled
    # action -- confirm against `model.pi`).
    pi_out = self.model.pi(z, task)[1]
    # return_type='all' presumably yields every Q head's raw output,
    # stacked along dim 0 -- TODO confirm.
    q_raw = self.model.Q(z, pi_out, task, return_type='all')
    q_values = math.two_hot_inv(q_raw, self.cfg)
    return q_values.mean() * q_values.std(0) * coef