This commit is contained in:
Nicklas Hansen
2023-12-25 10:11:42 -08:00
parent ca4dfa1db3
commit 2f86a1e4d8
2 changed files with 7 additions and 4 deletions

View File

@@ -233,7 +233,9 @@ class SliceSampler(Sampler):
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
return self._cache.setdefault("stop-and-length", vals)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
@@ -257,7 +259,9 @@ class SliceSampler(Sampler):
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
return self._cache.setdefault("stop-and-length", vals)
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True

View File

@@ -51,8 +51,7 @@ def train(cfg: dict):
cfg=cfg,
env=make_env(cfg),
agent=TDMPC2(cfg),
buffer=CropBuffer(cfg),
# buffer=SliceBuffer(cfg),
buffer=SliceBuffer(cfg),
logger=Logger(cfg),
)
trainer.train()