does not reproduce results with previous buffer
This commit is contained in:
@@ -37,13 +37,13 @@ class Buffer():
|
|||||||
return ReplayBuffer(
|
return ReplayBuffer(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=SliceSampler(
|
sampler=SliceSampler(
|
||||||
slice_len=self.cfg.horizon+1,
|
num_slices=self.cfg.batch_size,
|
||||||
end_key=None,
|
end_key=None,
|
||||||
traj_key='episode',
|
traj_key='episode',
|
||||||
truncated_key=None,
|
truncated_key=None,
|
||||||
),
|
),
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
prefetch=2,
|
prefetch=1,
|
||||||
batch_size=self.cfg.batch_size,
|
batch_size=self.cfg.batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class OnlineTrainer(Trainer):
|
|||||||
def to_td(self, obs, action=None, reward=None):
|
def to_td(self, obs, action=None, reward=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
|
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||||
else:
|
else:
|
||||||
obs = obs.unsqueeze(0).cpu()
|
obs = obs.unsqueeze(0).cpu()
|
||||||
if action is None:
|
if action is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user