From 392b16ac893705edbbf984991b70ce303a8d761c Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 3 Jan 2024 18:11:32 -0800 Subject: [PATCH] add uncertainty regularization --- tdmpc2/config.yaml | 1 + tdmpc2/tdmpc2.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..b529585 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -38,6 +38,7 @@ horizon: 3 min_std: 0.05 max_std: 2 temperature: 0.5 +uncertainty_coef: 0 # actor log_std_min: -10 diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..c70ee2e 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -90,6 +90,12 @@ class TDMPC2: else: a = self.model.pi(z, task)[int(not eval_mode)][0] return a.cpu() + + @torch.no_grad() + def _estimate_uncertainty(self, z, task): + """Estimates epistemic uncertainty, normalized by predicted value.""" + qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg) + return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef @torch.no_grad() def _estimate_value(self, z, actions, task): @@ -98,9 +104,10 @@ class TDMPC2: for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G += discount * reward + G += discount * (reward - self._estimate_uncertainty(z, task)) discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount - return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') + terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') + return G + discount * (terminal_value - self._estimate_uncertainty(z, task)) @torch.no_grad() def plan(self, z, t0=False, eval_mode=False, task=None):