only instantiate termination pred head if episodic=true

This commit is contained in:
Nicklas Hansen
2025-05-02 16:51:24 -07:00
parent 38b31a5d72
commit 7ec6bc83a8
2 changed files with 7 additions and 2 deletions

View File

@@ -99,7 +99,11 @@ class Buffer():
obs = td.get('obs').contiguous()
action = td.get('action')[1:].contiguous()
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
terminated = td.get('terminated', None)
if terminated is not None:
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
else:
terminated = torch.zeros_like(reward)
task = td.get('task', None)
if task is not None:
task = task[0].contiguous()

View File

@@ -280,7 +280,8 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
if self.cfg.episodic:
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
# Compute losses
reward_loss, value_loss = 0, 0