further reduce buffer differences

This commit is contained in:
Nicklas Hansen
2023-12-23 09:39:13 -08:00
parent eef1d1b407
commit ca4dfa1db3
2 changed files with 18 additions and 17 deletions

View File

@@ -48,6 +48,7 @@ class Buffer():
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
print('Buffer capacity:', self._capacity)
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
@@ -81,9 +82,18 @@ class Buffer():
task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task)
def _add(self, td):
"""Internal function that adds episode to the buffer."""
pass
def add(self, td):
"""Add an episode to the buffer."""
pass
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0:
self._buffer = self._init(td)
self._add(td)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
@@ -103,13 +113,9 @@ class CropBuffer(Buffer):
self._transform = RandomCropTensorDict(cfg.horizon+1, -1)
self._batch_size = cfg.batch_size
def add(self, td):
"""Add an episode to the buffer. All episodes are expected to be equal length."""
if self._num_eps == 0:
self._buffer = self._init(td)
def _add(self, td):
"""Add an episode to the buffer, with trajectories as the leading dimension."""
self._buffer.add(td)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of subsequences from the buffer."""
@@ -134,14 +140,9 @@ class SliceBuffer(Buffer):
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
def add(self, td):
"""Add an episode to the buffer. Supports variable episode lengths."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0:
self._buffer = self._init(td)
def _add(self, td):
"""Add an episode to the buffer, with transitions as the leading dimension."""
self._buffer.extend(td)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of subsequences from the buffer."""

View File

@@ -10,19 +10,19 @@ from common.buffer import CropBuffer, SliceBuffer
@hydra.main(config_name='config', config_path='.')
def test_buffer(cfg: dict):
cfg.episode_length = 12
cfg.episode_length = 11
cfg.batch_size = 8
transitions0 = [TensorDict(dict(
obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
action=torch.tensor([-1.]) ** t,
action=torch.tensor([-1.]).unsqueeze(0) ** t,
reward=torch.tensor([1.]) * t,
), batch_size=(1,)) for t in range(cfg.episode_length)]
episode0 = torch.cat(transitions0)
transitions1 = [TensorDict(dict(
obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
action=(torch.tensor([-1.]) ** t) * 0.5,
action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5,
reward=torch.tensor([-1.]) * t,
), batch_size=(1,)) for t in range(cfg.episode_length)]
episode1 = torch.cat(transitions1)