faster offline data loading

This commit is contained in:
Nicklas Hansen
2024-12-19 06:52:31 -08:00
parent df8a465c8e
commit 9cac7c5775
3 changed files with 30 additions and 16 deletions

View File

@@ -65,6 +65,30 @@ class Buffer():
LazyTensorStorage(self._capacity, device=self._storage_device)
)
def load(self, td):
    """
    Load a batch of episodes into the buffer in a single call.

    This is useful for loading data from disk, and is more efficient
    than adding episodes one by one.

    Args:
        td: Batch of episodes. It is indexed along its first dimension to
            obtain one episode and flattened over its first two dimensions
            before being stored, so it presumably has batch shape
            (num_episodes, episode_length) -- TODO confirm against caller.
            An 'episode' entry is written into `td` in place.

    Returns:
        Total number of episodes tracked by the buffer after loading.
    """
    num_new_eps = len(td)
    if num_new_eps == 0:
        # Nothing to store. Returning early also avoids indexing td[0] on an
        # empty batch when the buffer has not been created yet.
        return self._num_eps
    steps_per_ep = td['reward'].shape[1]
    # Consecutive episode ids continuing from the episodes already stored,
    # repeated along the second dimension so every step carries its id.
    episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
    td['episode'] = episode_idx.unsqueeze(-1).expand(-1, steps_per_ep)
    if self._num_eps == 0:
        # First data seen: create the buffer via _init from the first episode.
        self._buffer = self._init(td[0])
    # Collapse the first two dimensions into one before extending the buffer.
    td = td.reshape(td.shape[0]*td.shape[1])
    self._buffer.extend(td)
    self._num_eps += num_new_eps
    return self._num_eps
def add(self, td):
    """
    Append a single episode to the buffer.

    The episode is tagged under the 'episode' key (written into `td` in
    place) with the number of episodes stored before this call, using the
    same shape as `td['reward']`. When no episode has been stored yet, the
    buffer is first created via `_init`.

    Returns:
        Number of episodes tracked by the buffer after the insertion.
    """
    episode_id = self._num_eps
    td['episode'] = torch.full_like(
        td['reward'],
        episode_id,
        dtype=torch.int64,
    )
    buffer_is_uninitialized = episode_id == 0
    if buffer_is_uninitialized:
        self._buffer = self._init(td)
    self._buffer.extend(td)
    self._num_eps += 1
    return self._num_eps
def _prepare_batch(self, td):
"""
Prepare a sampled batch for training (post-processing).
@@ -79,15 +103,6 @@ class Buffer():
task = task[0].contiguous()
return obs, action, reward, task
def add(self, td):
    """
    Store one episode in the buffer and return the updated episode count.

    An 'episode' entry shaped like `td['reward']` and filled with the
    current episode count is written into `td` in place before it is
    stored. The buffer itself is created via `_init` on the first call,
    i.e. while the episode count is still zero.
    """
    current_count = self._num_eps
    reward = td['reward']
    td['episode'] = torch.full_like(reward, current_count, dtype=torch.int64)
    if not current_count != 0:
        self._buffer = self._init(td)
    self._buffer.extend(td)
    self._num_eps += 1
    return self._num_eps
def sample(self):
"""Sample a batch of subsequences from the buffer."""
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)

View File

@@ -137,7 +137,7 @@ class WorldModel(nn.Module):
eps = torch.randn_like(mean)
if self.cfg.multitask: # Mask out unused action dimensions
mu = mu * self._action_masks[task]
mean = mean * self._action_masks[task]
log_std = log_std * self._action_masks[task]
eps = eps * self._action_masks[task]
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)

View File

@@ -45,8 +45,8 @@ class OfflineTrainer(Trainer):
fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}')
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
# Create buffer for sampling
_cfg = deepcopy(self.cfg)
@@ -59,11 +59,10 @@ class OfflineTrainer(Trainer):
assert td.shape[1] == _cfg.episode_length, \
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
f'please double-check your config.'
for i in range(len(td)):
self.buffer.add(td[i])
self.buffer.load(td)
expected_episodes = _cfg.buffer_size // _cfg.episode_length
assert self.buffer.num_eps == expected_episodes, \
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
if self.buffer.num_eps != expected_episodes:
print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
def train(self):
"""Train a TD-MPC2 agent."""