This commit is contained in:
Nicklas Hansen
2023-12-25 10:11:42 -08:00
parent ca4dfa1db3
commit 2f86a1e4d8
2 changed files with 7 additions and 4 deletions

View File

@@ -233,7 +233,9 @@ class SliceSampler(Sampler):
and self._used_traj_key[0] == "_data" and self._used_traj_key[0] == "_data"
) )
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
return self._cache.setdefault("stop-and-length", vals) if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError: except KeyError:
if fallback: if fallback:
self._fetch_traj = False self._fetch_traj = False
@@ -257,7 +259,9 @@ class SliceSampler(Sampler):
and self._used_end_key[0] == "_data" and self._used_end_key[0] == "_data"
) )
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
return self._cache.setdefault("stop-and-length", vals) if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError: except KeyError:
if fallback: if fallback:
self._fetch_traj = True self._fetch_traj = True

View File

@@ -51,8 +51,7 @@ def train(cfg: dict):
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),
agent=TDMPC2(cfg), agent=TDMPC2(cfg),
buffer=CropBuffer(cfg), buffer=SliceBuffer(cfg),
# buffer=SliceBuffer(cfg),
logger=Logger(cfg), logger=Logger(cfg),
) )
trainer.train() trainer.train()