diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py index 71cf70c..61c9f40 100644 --- a/tdmpc2/common/samplers.py +++ b/tdmpc2/common/samplers.py @@ -300,7 +300,7 @@ class SliceSampler(Sampler): relative_starts = ( ( torch.rand(num_slices, device=lengths.device) - * (lengths[traj_idx] - seq_length) + * (lengths[traj_idx] - seq_length + 1) ) .floor() .to(start_idx.dtype) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index f1ee97e..a3326bc 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -47,10 +47,8 @@ class OnlineTrainer(Trainer): episode_success=np.nanmean(ep_successes), ) - def to_td(self, obs=None, action=None, reward=None): + def to_td(self, obs, action=None, reward=None): """Creates a TensorDict for a new episode.""" - if obs is None: - obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan')) if isinstance(obs, dict): obs = TensorDict(obs, batch_size=(), device='cpu') else: @@ -90,7 +88,6 @@ class OnlineTrainer(Trainer): ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._tds.append(self.to_td()) # Separate episodes with NaNs self._ep_idx = self.buffer.add(torch.cat(self._tds)) obs = self.env.reset()