experimental changes to termination prediction
This commit is contained in:
@@ -127,14 +127,16 @@ class WorldModel(nn.Module):
|
||||
z = torch.cat([z, a], dim=-1)
|
||||
return self._reward(z)
|
||||
|
||||
def termination(self, z, task):
|
||||
def termination(self, z, task, sigmoid=True):
|
||||
"""
|
||||
Predicts termination signal.
|
||||
"""
|
||||
assert task is None
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
return torch.sigmoid(self._termination(z))
|
||||
if sigmoid:
|
||||
return torch.sigmoid(self._termination(z))
|
||||
return self._termination(z)
|
||||
|
||||
def pi(self, z, task):
|
||||
"""
|
||||
@@ -184,11 +186,12 @@ class WorldModel(nn.Module):
|
||||
`return_type` can be one of [`min`, `avg`, `all`]:
|
||||
- `min`: return the minimum of two randomly subsampled Q-values.
|
||||
- `avg`: return the average of two randomly subsampled Q-values.
|
||||
- 'min-all': return the minimum of all Q-values.
|
||||
- 'avg-all': return the average of all Q-values.
|
||||
- `all`: return all Q-values.
|
||||
`target` specifies whether to use the target Q-networks or not.
|
||||
"""
|
||||
assert return_type in {'min', 'avg', 'avg-all', 'all'}
|
||||
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
|
||||
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
@@ -208,6 +211,10 @@ class WorldModel(nn.Module):
|
||||
if return_type == 'avg-all':
|
||||
Q = math.two_hot_inv(out, self.cfg)
|
||||
return Q.mean(0)
|
||||
|
||||
if return_type == 'min-all':
|
||||
Q = math.two_hot_inv(out, self.cfg)
|
||||
return Q.min(0).values
|
||||
|
||||
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
||||
Q = math.two_hot_inv(out[qidx], self.cfg)
|
||||
|
||||
@@ -16,7 +16,7 @@ steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
termination_coef: 20
|
||||
termination_coef: 1
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
lr: 3e-4
|
||||
|
||||
@@ -254,7 +254,7 @@ class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
action, _ = self.model.pi(next_z, task)
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True)
|
||||
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
|
||||
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||
|
||||
def _update(self, obs, action, reward, terminated, task=None):
|
||||
@@ -280,7 +280,7 @@ class TDMPC2(torch.nn.Module):
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
termination_pred = self.model.termination(zs[-1], task)
|
||||
termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, value_loss = 0, 0
|
||||
@@ -291,7 +291,12 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
consistency_loss = consistency_loss / self.cfg.horizon
|
||||
reward_loss = reward_loss / self.cfg.horizon
|
||||
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1])
|
||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated)
|
||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
|
||||
# weighted mean over time, with last time step weighted as much as the rest combined
|
||||
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
|
||||
# termination_loss = termination_loss.mean()
|
||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
|
||||
Reference in New Issue
Block a user