From 2f86a1e4d8f49abdf49acc1b65f77907abf7e0e9 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 25 Dec 2023 10:11:42 -0800 Subject: [PATCH] fix sampler https://github.com/pytorch/rl/pull/1762 --- tdmpc2/common/samplers.py | 8 ++++++-- tdmpc2/train.py | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py index af0e073..a8cd72f 100644 --- a/tdmpc2/common/samplers.py +++ b/tdmpc2/common/samplers.py @@ -233,7 +233,9 @@ class SliceSampler(Sampler): and self._used_traj_key[0] == "_data" ) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) - return self._cache.setdefault("stop-and-length", vals) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals except KeyError: if fallback: self._fetch_traj = False @@ -257,7 +259,9 @@ class SliceSampler(Sampler): and self._used_end_key[0] == "_data" ) vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] - return self._cache.setdefault("stop-and-length", vals) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals except KeyError: if fallback: self._fetch_traj = True diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5303e09..b091bec 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -51,8 +51,7 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=CropBuffer(cfg), - # buffer=SliceBuffer(cfg), + buffer=SliceBuffer(cfg), logger=Logger(cfg), ) trainer.train()