does not reproduce results w/ previous buffer

This commit is contained in:
Nicklas Hansen
2023-12-22 14:26:48 -08:00
parent 2929cfdb44
commit 70fe242adc
2 changed files with 3 additions and 3 deletions

View File

@@ -37,13 +37,13 @@ class Buffer():
return ReplayBuffer( return ReplayBuffer(
storage=storage, storage=storage,
sampler=SliceSampler( sampler=SliceSampler(
slice_len=self.cfg.horizon+1, num_slices=self.cfg.batch_size,
end_key=None, end_key=None,
traj_key='episode', traj_key='episode',
truncated_key=None, truncated_key=None,
), ),
pin_memory=True, pin_memory=True,
prefetch=2, prefetch=1,
batch_size=self.cfg.batch_size, batch_size=self.cfg.batch_size,
) )

View File

@@ -50,7 +50,7 @@ class OnlineTrainer(Trainer):
def to_td(self, obs, action=None, reward=None): def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() obs = TensorDict(obs, batch_size=(), device='cpu')
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None: