add uncertainty regularization

This commit is contained in:
Nicklas Hansen
2024-01-03 18:11:32 -08:00
parent 0f3bc77011
commit 392b16ac89
2 changed files with 10 additions and 2 deletions

View File

@@ -38,6 +38,7 @@ horizon: 3
min_std: 0.05
max_std: 2
temperature: 0.5
uncertainty_coef: 0
# actor
log_std_min: -10

View File

@@ -91,6 +91,12 @@ class TDMPC2:
a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu()
@torch.no_grad()
def _estimate_uncertainty(self, z, task):
"""Estimates epistemic uncertainty, normalized by predicted value."""
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
@torch.no_grad()
def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
@@ -98,9 +104,10 @@ class TDMPC2:
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
G += discount * (reward - self._estimate_uncertainty(z, task))
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):