fix sampler https://github.com/pytorch/rl/pull/1762
This commit is contained in:
@@ -233,7 +233,9 @@ class SliceSampler(Sampler):
|
||||
and self._used_traj_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
|
||||
return self._cache.setdefault("stop-and-length", vals)
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = False
|
||||
@@ -257,7 +259,9 @@ class SliceSampler(Sampler):
|
||||
and self._used_end_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
|
||||
return self._cache.setdefault("stop-and-length", vals)
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = True
|
||||
|
||||
@@ -51,8 +51,7 @@ def train(cfg: dict):
|
||||
cfg=cfg,
|
||||
env=make_env(cfg),
|
||||
agent=TDMPC2(cfg),
|
||||
buffer=CropBuffer(cfg),
|
||||
# buffer=SliceBuffer(cfg),
|
||||
buffer=SliceBuffer(cfg),
|
||||
logger=Logger(cfg),
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user