diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 1b9330c..75d17d2 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -48,6 +48,7 @@ class Buffer(): def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" + print('Buffer capacity:', self._capacity) mem_free, _ = torch.cuda.mem_get_info() bytes_per_ep = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ @@ -81,9 +82,18 @@ class Buffer(): task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, task) + def _add(self, td): + """Internal function that adds episode to the buffer.""" + pass + def add(self, td): """Add an episode to the buffer.""" - pass + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + if self._num_eps == 0: + self._buffer = self._init(td) + self._add(td) + self._num_eps += 1 + return self._num_eps def sample(self): """Sample a batch of sub-trajectories from the buffer.""" @@ -103,13 +113,9 @@ class CropBuffer(Buffer): self._transform = RandomCropTensorDict(cfg.horizon+1, -1) self._batch_size = cfg.batch_size - def add(self, td): - """Add an episode to the buffer. All episodes are expected to be equal length.""" - if self._num_eps == 0: - self._buffer = self._init(td) + def _add(self, td): + """Add an episode to the buffer, with trajectories as the leading dimension.""" self._buffer.add(td) - self._num_eps += 1 - return self._num_eps def sample(self): """Sample a batch of subsequences from the buffer.""" @@ -134,14 +140,9 @@ class SliceBuffer(Buffer): ) self._batch_size = cfg.batch_size * (cfg.horizon+1) - def add(self, td): - """Add an episode to the buffer. Supports variable episode lengths.""" - td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps - if self._num_eps == 0: - self._buffer = self._init(td) + def _add(self, td): + """Add an episode to the buffer, with transitions as the leading dimension.""" self._buffer.extend(td) - self._num_eps += 1 - return self._num_eps def sample(self): """Sample a batch of subsequences from the buffer.""" diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py index 7a58d08..87e9212 100644 --- a/tdmpc2/test_buffer.py +++ b/tdmpc2/test_buffer.py @@ -10,19 +10,19 @@ from common.buffer import CropBuffer, SliceBuffer @hydra.main(config_name='config', config_path='.') def test_buffer(cfg: dict): - cfg.episode_length = 12 + cfg.episode_length = 11 cfg.batch_size = 8 transitions0 = [TensorDict(dict( obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, - action=torch.tensor([-1.]) ** t, + action=torch.tensor([-1.]).unsqueeze(0) ** t, reward=torch.tensor([1.]) * t, ), batch_size=(1,)) for t in range(cfg.episode_length)] episode0 = torch.cat(transitions0) transitions1 = [TensorDict(dict( obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, - action=(torch.tensor([-1.]) ** t) * 0.5, + action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5, reward=torch.tensor([-1.]) * t, ), batch_size=(1,)) for t in range(cfg.episode_length)] episode1 = torch.cat(transitions1)