experimental changes to termination prediction

This commit is contained in:
Nicklas Hansen
2025-04-10 00:32:13 -07:00
parent c95b755655
commit 62be41ab58
3 changed files with 19 additions and 7 deletions

View File

@@ -127,14 +127,16 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def termination(self, z, task): def termination(self, z, task, sigmoid=True):
""" """
Predicts termination signal. Predicts termination signal.
""" """
assert task is None assert task is None
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
return torch.sigmoid(self._termination(z)) if sigmoid:
return torch.sigmoid(self._termination(z))
return self._termination(z)
def pi(self, z, task): def pi(self, z, task):
""" """
@@ -184,11 +186,12 @@ class WorldModel(nn.Module):
`return_type` can be one of [`min`, `avg`, `all`]: `return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values. - `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values. - `avg`: return the average of two randomly subsampled Q-values.
- 'min-all': return the minimum of all Q-values.
- 'avg-all': return the average of all Q-values. - 'avg-all': return the average of all Q-values.
- `all`: return all Q-values. - `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not. `target` specifies whether to use the target Q-networks or not.
""" """
assert return_type in {'min', 'avg', 'avg-all', 'all'} assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
@@ -208,6 +211,10 @@ class WorldModel(nn.Module):
if return_type == 'avg-all': if return_type == 'avg-all':
Q = math.two_hot_inv(out, self.cfg) Q = math.two_hot_inv(out, self.cfg)
return Q.mean(0) return Q.mean(0)
if return_type == 'min-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.min(0).values
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q = math.two_hot_inv(out[qidx], self.cfg) Q = math.two_hot_inv(out[qidx], self.cfg)

View File

@@ -16,7 +16,7 @@ steps: 10_000_000
batch_size: 256 batch_size: 256
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
termination_coef: 20 termination_coef: 1
consistency_coef: 20 consistency_coef: 20
rho: 0.5 rho: 0.5
lr: 3e-4 lr: 3e-4

View File

@@ -254,7 +254,7 @@ class TDMPC2(torch.nn.Module):
""" """
action, _ = self.model.pi(next_z, task) action, _ = self.model.pi(next_z, task)
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True) # return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, terminated, task=None): def _update(self, obs, action, reward, terminated, task=None):
@@ -280,7 +280,7 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
termination_pred = self.model.termination(zs[-1], task) termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
# Compute losses # Compute losses
reward_loss, value_loss = 0, 0 reward_loss, value_loss = 0, 0
@@ -291,7 +291,12 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1]) # termination_loss = F.binary_cross_entropy(termination_pred, terminated)
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
# weighted mean over time, with last time step weighted as much as the rest combined
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
# termination_loss = termination_loss.mean()
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss + self.cfg.consistency_coef * consistency_loss +