amend slicesampler indexing
This commit is contained in:
@@ -300,7 +300,7 @@ class SliceSampler(Sampler):
|
|||||||
relative_starts = (
|
relative_starts = (
|
||||||
(
|
(
|
||||||
torch.rand(num_slices, device=lengths.device)
|
torch.rand(num_slices, device=lengths.device)
|
||||||
* (lengths[traj_idx] - seq_length)
|
* (lengths[traj_idx] - seq_length + 1)
|
||||||
)
|
)
|
||||||
.floor()
|
.floor()
|
||||||
.to(start_idx.dtype)
|
.to(start_idx.dtype)
|
||||||
|
|||||||
@@ -47,10 +47,8 @@ class OnlineTrainer(Trainer):
|
|||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs=None, action=None, reward=None):
|
def to_td(self, obs, action=None, reward=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
if obs is None:
|
|
||||||
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
|
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict(obs, batch_size=(), device='cpu')
|
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||||
else:
|
else:
|
||||||
@@ -90,7 +88,6 @@ class OnlineTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._tds.append(self.to_td()) # Separate episodes with NaNs
|
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
|
|||||||
Reference in New Issue
Block a user