From 62be41ab581b803845896f4b84e6afcefe5608b9 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Thu, 10 Apr 2025 00:32:13 -0700 Subject: [PATCH] experimental changes to termination prediction --- tdmpc2/common/world_model.py | 13 ++++++++++--- tdmpc2/config.yaml | 2 +- tdmpc2/tdmpc2.py | 11 ++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 91e581e..d0040b8 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -127,14 +127,16 @@ class WorldModel(nn.Module): z = torch.cat([z, a], dim=-1) return self._reward(z) - def termination(self, z, task): + def termination(self, z, task, sigmoid=True): """ Predicts termination signal. """ assert task is None if self.cfg.multitask: z = self.task_emb(z, task) - return torch.sigmoid(self._termination(z)) + if sigmoid: + return torch.sigmoid(self._termination(z)) + return self._termination(z) def pi(self, z, task): """ @@ -184,11 +186,12 @@ class WorldModel(nn.Module): `return_type` can be one of [`min`, `avg`, `all`]: - `min`: return the minimum of two randomly subsampled Q-values. - `avg`: return the average of two randomly subsampled Q-values. + - 'min-all': return the minimum of all Q-values. - 'avg-all': return the average of all Q-values. - `all`: return all Q-values. `target` specifies whether to use the target Q-networks or not. """ - assert return_type in {'min', 'avg', 'avg-all', 'all'} + assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'} if self.cfg.multitask: z = self.task_emb(z, task) @@ -208,6 +211,10 @@ class WorldModel(nn.Module): if return_type == 'avg-all': Q = math.two_hot_inv(out, self.cfg) return Q.mean(0) + + if return_type == 'min-all': + Q = math.two_hot_inv(out, self.cfg) + return Q.min(0).values qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] Q = math.two_hot_inv(out[qidx], self.cfg) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 6d3510c..a077bb6 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -16,7 +16,7 @@ steps: 10_000_000 batch_size: 256 reward_coef: 0.1 value_coef: 0.1 -termination_coef: 20 +termination_coef: 1 consistency_coef: 20 rho: 0.5 lr: 3e-4 diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 8ab3c2b..61de435 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -254,7 +254,7 @@ class TDMPC2(torch.nn.Module): """ action, _ = self.model.pi(next_z, task) discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount - # return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True) + # return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True) def _update(self, obs, action, reward, terminated, task=None): @@ -280,7 +280,7 @@ class TDMPC2(torch.nn.Module): _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - termination_pred = self.model.termination(zs[-1], task) + termination_pred = self.model.termination(zs[1:], task, sigmoid=False) # Compute losses reward_loss, value_loss = 0, 0 @@ -291,7 +291,12 @@ class TDMPC2(torch.nn.Module): consistency_loss = consistency_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon - termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1]) + # termination_loss = F.binary_cross_entropy(termination_pred, terminated) + termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated) + # termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none') + # weighted mean over time, with last time step weighted as much as the rest combined + # termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2) + # termination_loss = termination_loss.mean() value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss +