Further reduce buffer differences: move the shared add() logic into Buffer so that CropBuffer and SliceBuffer only override _add()
This commit is contained in:
@@ -48,6 +48,7 @@ class Buffer():
|
||||
|
||||
def _init(self, tds):
|
||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||
print('Buffer capacity:', self._capacity)
|
||||
mem_free, _ = torch.cuda.mem_get_info()
|
||||
bytes_per_ep = sum([
|
||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||
@@ -81,9 +82,18 @@ class Buffer():
|
||||
task = td['task'][0] if 'task' in td.keys() else None
|
||||
return self._to_device(obs, action, reward, task)
|
||||
|
||||
def _add(self, td):
|
||||
"""Internal function that adds episode to the buffer."""
|
||||
pass
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
pass
|
||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
self._add(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of sub-trajectories from the buffer."""
|
||||
@@ -103,13 +113,9 @@ class CropBuffer(Buffer):
|
||||
self._transform = RandomCropTensorDict(cfg.horizon+1, -1)
|
||||
self._batch_size = cfg.batch_size
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer. All episodes are expected to be equal length."""
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
def _add(self, td):
|
||||
"""Add an episode to the buffer, with trajectories as the leading dimension."""
|
||||
self._buffer.add(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of subsequences from the buffer."""
|
||||
@@ -134,14 +140,9 @@ class SliceBuffer(Buffer):
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer. Supports variable episode lengths."""
|
||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
def _add(self, td):
|
||||
"""Add an episode to the buffer, with transitions as the leading dimension."""
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of subsequences from the buffer."""
|
||||
|
||||
@@ -10,19 +10,19 @@ from common.buffer import CropBuffer, SliceBuffer
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
|
||||
def test_buffer(cfg: dict):
|
||||
cfg.episode_length = 12
|
||||
cfg.episode_length = 11
|
||||
cfg.batch_size = 8
|
||||
|
||||
transitions0 = [TensorDict(dict(
|
||||
obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
|
||||
action=torch.tensor([-1.]) ** t,
|
||||
action=torch.tensor([-1.]).unsqueeze(0) ** t,
|
||||
reward=torch.tensor([1.]) * t,
|
||||
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
||||
episode0 = torch.cat(transitions0)
|
||||
|
||||
transitions1 = [TensorDict(dict(
|
||||
obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
|
||||
action=(torch.tensor([-1.]) ** t) * 0.5,
|
||||
action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5,
|
||||
reward=torch.tensor([-1.]) * t,
|
||||
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
||||
episode1 = torch.cat(transitions1)
|
||||
|
||||
Reference in New Issue
Block a user