Amend SliceSampler indexing

This commit is contained in:
Nicklas Hansen
2024-01-08 10:48:04 -08:00
parent 31249a8961
commit 0f3bc77011
2 changed files with 2 additions and 5 deletions

View File

@@ -300,7 +300,7 @@ class SliceSampler(Sampler):
relative_starts = ( relative_starts = (
( (
torch.rand(num_slices, device=lengths.device) torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length) * (lengths[traj_idx] - seq_length + 1)
) )
.floor() .floor()
.to(start_idx.dtype) .to(start_idx.dtype)

View File

@@ -47,10 +47,8 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
) )
def to_td(self, obs=None, action=None, reward=None): def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if obs is None:
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu') obs = TensorDict(obs, batch_size=(), device='cpu')
else: else:
@@ -90,7 +88,6 @@ class OnlineTrainer(Trainer):
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._tds.append(self.to_td()) # Separate episodes with NaNs
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() obs = self.env.reset()