From 70fe242adc23cc19d6f6e466bcc423f7f6523c21 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 14:26:48 -0800 Subject: [PATCH] does not reproduce results w/ previous buffer --- tdmpc2/common/buffer.py | 4 ++-- tdmpc2/trainer/online_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 7010549..8d914a7 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -37,13 +37,13 @@ class Buffer(): return ReplayBuffer( storage=storage, sampler=SliceSampler( - slice_len=self.cfg.horizon+1, + num_slices=self.cfg.batch_size, end_key=None, traj_key='episode', truncated_key=None, ), pin_memory=True, - prefetch=2, + prefetch=1, batch_size=self.cfg.batch_size, ) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 94835ca..f5f65cc 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -50,7 +50,7 @@ class OnlineTrainer(Trainer): def to_td(self, obs, action=None, reward=None): """Creates a TensorDict for a new episode.""" if isinstance(obs, dict): - obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() + obs = TensorDict(obs, batch_size=(), device='cpu') else: obs = obs.unsqueeze(0).cpu() if action is None: