diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 84e49e1..c23b5f8 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -20,6 +20,7 @@ class Buffer(): traj_key='episode', truncated_key=None, strict_length=True, + cache_values=cfg.multitask, ) self._batch_size = cfg.batch_size * (cfg.horizon+1) self._num_eps = 0