diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml
index b720923..b529585 100755
--- a/tdmpc2/config.yaml
+++ b/tdmpc2/config.yaml
@@ -38,6 +38,7 @@ horizon: 3
 min_std: 0.05
 max_std: 2
 temperature: 0.5
+uncertainty_coef: 0
 
 # actor
 log_std_min: -10
diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py
index 3925359..2afd8dd 100755
--- a/tdmpc2/tdmpc2.py
+++ b/tdmpc2/tdmpc2.py
@@ -87,6 +87,25 @@ class TDMPC2:
 		z = self.model.encode(obs, task)
 		a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
 		return a.cpu()
+
+	@torch.no_grad()
+	def _estimate_uncertainty(self, z, task):
+		"""
+		Uncertainty penalty for latent state z, as subtracted in `_estimate_value`.
+
+		Decodes the output of `self.model.Q(..., return_type='all')` for the action
+		`self.model.pi(z, task)[1]` with `math.two_hot_inv`, then returns the std of
+		those values over dim 0, multiplied by their overall mean and by
+		`cfg.uncertainty_coef`. When the coefficient is 0 the method returns 0
+		without calling the model.
+		"""
+		if self.cfg.uncertainty_coef == 0:
+			# Penalty disabled: skip the extra pi/Q forward passes at every planning step.
+			return 0
+		qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
+		# NOTE(review): dim 0 is presumably the Q-ensemble axis -- confirm against model.Q.
+		# NOTE(review): qs.mean() is one scalar over every entry and multiplies the std;
+		# the original docstring said "normalized by predicted value" -- confirm intent.
+		return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
 
 	@torch.no_grad()
 	def _estimate_value(self, z, actions, task):
@@ -95,9 +114,12 @@ class TDMPC2:
 		for t in range(self.cfg.horizon):
 			reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
 			z = self.model.next(z, actions[t], task)
+			# z is already the next latent here, so the penalty is evaluated on the state the action leads to.
-			G += discount * reward
+			G += discount * (reward - self._estimate_uncertainty(z, task))
 			discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
-		return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
+		# NOTE(review): the last loop iteration already penalised this same z -- confirm the double count is intended.
+		terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
+		return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
 
 	@torch.no_grad()
 	def plan(self, z, t0=False, eval_mode=False, task=None):