fix sampler https://github.com/pytorch/rl/pull/1762
This commit is contained in:
@@ -233,7 +233,9 @@ class SliceSampler(Sampler):
|
|||||||
and self._used_traj_key[0] == "_data"
|
and self._used_traj_key[0] == "_data"
|
||||||
)
|
)
|
||||||
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
|
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
|
||||||
return self._cache.setdefault("stop-and-length", vals)
|
if self.cache_values:
|
||||||
|
self._cache["stop-and-length"] = vals
|
||||||
|
return vals
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if fallback:
|
if fallback:
|
||||||
self._fetch_traj = False
|
self._fetch_traj = False
|
||||||
@@ -257,7 +259,9 @@ class SliceSampler(Sampler):
|
|||||||
and self._used_end_key[0] == "_data"
|
and self._used_end_key[0] == "_data"
|
||||||
)
|
)
|
||||||
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
|
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
|
||||||
return self._cache.setdefault("stop-and-length", vals)
|
if self.cache_values:
|
||||||
|
self._cache["stop-and-length"] = vals
|
||||||
|
return vals
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if fallback:
|
if fallback:
|
||||||
self._fetch_traj = True
|
self._fetch_traj = True
|
||||||
|
|||||||
@@ -51,8 +51,7 @@ def train(cfg: dict):
|
|||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
agent=TDMPC2(cfg),
|
agent=TDMPC2(cfg),
|
||||||
buffer=CropBuffer(cfg),
|
buffer=SliceBuffer(cfg),
|
||||||
# buffer=SliceBuffer(cfg),
|
|
||||||
logger=Logger(cfg),
|
logger=Logger(cfg),
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user