From 31249a8961f5125d55acb01586e80f1a74d2ee73 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 14:21:38 -0800 Subject: [PATCH] separate episodes with nans --- tdmpc2/trainer/online_trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index ca33009..f1ee97e 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -47,8 +47,10 @@ class OnlineTrainer(Trainer): episode_success=np.nanmean(ep_successes), ) - def to_td(self, obs, action=None, reward=None): + def to_td(self, obs=None, action=None, reward=None): """Creates a TensorDict for a new episode.""" + if obs is None: + obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan')) if isinstance(obs, dict): obs = TensorDict(obs, batch_size=(), device='cpu') else: @@ -62,7 +64,7 @@ class OnlineTrainer(Trainer): action=action.unsqueeze(0), reward=reward.unsqueeze(0), ), batch_size=(1,)) - return td + return td def train(self): """Train a TD-MPC2 agent.""" @@ -88,6 +90,7 @@ class OnlineTrainer(Trainer): ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') + self._tds.append(self.to_td()) # Separate episodes with NaNs self._ep_idx = self.buffer.add(torch.cat(self._tds)) obs = self.env.reset()