experimental changes to termination prediction

This commit is contained in:
Nicklas Hansen
2025-04-10 00:32:13 -07:00
parent c95b755655
commit 62be41ab58
3 changed files with 19 additions and 7 deletions

View File

@@ -127,14 +127,16 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def termination(self, z, task):
def termination(self, z, task, sigmoid=True):
"""
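Predicts termination signal.
If `sigmoid` is True (default), the sigmoid of the termination head's output is returned;
otherwise the raw (pre-sigmoid) output is returned.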
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
return torch.sigmoid(self._termination(z))
if sigmoid:
return torch.sigmoid(self._termination(z))
return self._termination(z)
def pi(self, z, task):
"""
@@ -184,11 +186,12 @@ class WorldModel(nn.Module):
`return_type` can be one of [`min`, `avg`, `min-all`, `avg-all`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values.
- `min-all`: return the minimum of all Q-values.
- `avg-all`: return the average of all Q-values.
- `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not.
"""
assert return_type in {'min', 'avg', 'avg-all', 'all'}
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
if self.cfg.multitask:
z = self.task_emb(z, task)
@@ -208,6 +211,10 @@ class WorldModel(nn.Module):
if return_type == 'avg-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.mean(0)
if return_type == 'min-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.min(0).values
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q = math.two_hot_inv(out[qidx], self.cfg)

View File

@@ -16,7 +16,7 @@ steps: 10_000_000
batch_size: 256
reward_coef: 0.1
value_coef: 0.1
termination_coef: 20
termination_coef: 1
consistency_coef: 20
rho: 0.5
lr: 3e-4

View File

@@ -254,7 +254,7 @@ class TDMPC2(torch.nn.Module):
"""
action, _ = self.model.pi(next_z, task)
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True)
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, terminated, task=None):
@@ -280,7 +280,7 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
termination_pred = self.model.termination(zs[-1], task)
termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
# Compute losses
reward_loss, value_loss = 0, 0
@@ -291,7 +291,12 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1])
# termination_loss = F.binary_cross_entropy(termination_pred, terminated)
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
# weighted mean over time, with last time step weighted as much as the rest combined
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
# termination_loss = termination_loss.mean()
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = (
self.cfg.consistency_coef * consistency_loss +