add uncertainty regularization
This commit is contained in:
@@ -38,6 +38,7 @@ horizon: 3
|
|||||||
min_std: 0.05
|
min_std: 0.05
|
||||||
max_std: 2
|
max_std: 2
|
||||||
temperature: 0.5
|
temperature: 0.5
|
||||||
|
uncertainty_coef: 0
|
||||||
|
|
||||||
# actor
|
# actor
|
||||||
log_std_min: -10
|
log_std_min: -10
|
||||||
|
|||||||
@@ -90,6 +90,12 @@ class TDMPC2:
|
|||||||
else:
|
else:
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||||
return a.cpu()
|
return a.cpu()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def _estimate_uncertainty(self, z, task):
|
||||||
|
"""Estimates epistemic uncertainty, normalized by predicted value."""
|
||||||
|
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
|
||||||
|
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
@@ -98,9 +104,10 @@ class TDMPC2:
|
|||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[t], task)
|
||||||
G += discount * reward
|
G += discount * (reward - self._estimate_uncertainty(z, task))
|
||||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||||
|
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def plan(self, z, t0=False, eval_mode=False, task=None):
|
def plan(self, z, t0=False, eval_mode=False, task=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user