diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 9b74c70..4b16f6f 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -99,7 +99,11 @@ class Buffer(): obs = td.get('obs').contiguous() action = td.get('action')[1:].contiguous() reward = td.get('reward')[1:].unsqueeze(-1).contiguous() - terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous() + terminated = td.get('terminated', None) + if terminated is not None: + terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous() + else: + terminated = torch.zeros_like(reward) task = td.get('task', None) if task is not None: task = task[0].contiguous() diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index bdb5225..7111329 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -280,7 +280,8 @@ class TDMPC2(torch.nn.Module): _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - termination_pred = self.model.termination(zs[1:], task, unnormalized=True) + if self.cfg.episodic: + termination_pred = self.model.termination(zs[1:], task, unnormalized=True) # Compute losses reward_loss, value_loss = 0, 0