Separate episodes with NaNs

This commit is contained in:
Nicklas Hansen
2024-01-07 14:21:38 -08:00
parent 33876d124f
commit 31249a8961

View File

@@ -47,8 +47,10 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes),
)
def to_td(self, obs, action=None, reward=None):
def to_td(self, obs=None, action=None, reward=None):
"""Creates a TensorDict for a new episode."""
if obs is None:
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu')
else:
@@ -62,7 +64,7 @@ class OnlineTrainer(Trainer):
action=action.unsqueeze(0),
reward=reward.unsqueeze(0),
), batch_size=(1,))
return td
return td
def train(self):
"""Train a TD-MPC2 agent."""
@@ -88,6 +90,7 @@ class OnlineTrainer(Trainer):
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._tds.append(self.to_td()) # Separate episodes with NaNs
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset()