Only instantiate the termination prediction head if episodic=true
This commit is contained in:
@@ -99,7 +99,11 @@ class Buffer():
|
||||
obs = td.get('obs').contiguous()
|
||||
action = td.get('action')[1:].contiguous()
|
||||
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
||||
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
|
||||
terminated = td.get('terminated', None)
|
||||
if terminated is not None:
|
||||
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
|
||||
else:
|
||||
terminated = torch.zeros_like(reward)
|
||||
task = td.get('task', None)
|
||||
if task is not None:
|
||||
task = task[0].contiguous()
|
||||
|
||||
@@ -280,7 +280,8 @@ class TDMPC2(torch.nn.Module):
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||
if self.cfg.episodic:
|
||||
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, value_loss = 0, 0
|
||||
|
||||
Reference in New Issue
Block a user