Only instantiate the termination prediction head if episodic=true
This commit is contained in:
@@ -99,7 +99,11 @@ class Buffer():
|
|||||||
obs = td.get('obs').contiguous()
|
obs = td.get('obs').contiguous()
|
||||||
action = td.get('action')[1:].contiguous()
|
action = td.get('action')[1:].contiguous()
|
||||||
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
||||||
|
terminated = td.get('terminated', None)
|
||||||
|
if terminated is not None:
|
||||||
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
|
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
|
||||||
|
else:
|
||||||
|
terminated = torch.zeros_like(reward)
|
||||||
task = td.get('task', None)
|
task = td.get('task', None)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
task = task[0].contiguous()
|
task = task[0].contiguous()
|
||||||
|
|||||||
@@ -280,6 +280,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
|
if self.cfg.episodic:
|
||||||
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
|
|||||||
Reference in New Issue
Block a user