faster offline data loading
This commit is contained in:
@@ -65,6 +65,30 @@ class Buffer():
|
||||
LazyTensorStorage(self._capacity, device=self._storage_device)
|
||||
)
|
||||
|
||||
def load(self, td):
|
||||
"""
|
||||
Load a batch of episodes into the buffer. This is useful for loading data from disk,
|
||||
and is more efficient than adding episodes one by one.
|
||||
"""
|
||||
num_new_eps = len(td)
|
||||
episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
|
||||
td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td[0])
|
||||
td = td.reshape(td.shape[0]*td.shape[1])
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += num_new_eps
|
||||
return self._num_eps
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def _prepare_batch(self, td):
|
||||
"""
|
||||
Prepare a sampled batch for training (post-processing).
|
||||
@@ -79,15 +103,6 @@ class Buffer():
|
||||
task = task[0].contiguous()
|
||||
return obs, action, reward, task
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of subsequences from the buffer."""
|
||||
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||
|
||||
@@ -137,7 +137,7 @@ class WorldModel(nn.Module):
|
||||
eps = torch.randn_like(mean)
|
||||
|
||||
if self.cfg.multitask: # Mask out unused action dimensions
|
||||
mu = mu * self._action_masks[task]
|
||||
mean = mean * self._action_masks[task]
|
||||
log_std = log_std * self._action_masks[task]
|
||||
eps = eps * self._action_masks[task]
|
||||
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
|
||||
|
||||
@@ -45,8 +45,8 @@ class OfflineTrainer(Trainer):
|
||||
fps = sorted(glob(str(fp)))
|
||||
assert len(fps) > 0, f'No data found at {fp}'
|
||||
print(f'Found {len(fps)} files in {fp}')
|
||||
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
|
||||
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
|
||||
if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
|
||||
print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
|
||||
|
||||
# Create buffer for sampling
|
||||
_cfg = deepcopy(self.cfg)
|
||||
@@ -59,11 +59,10 @@ class OfflineTrainer(Trainer):
|
||||
assert td.shape[1] == _cfg.episode_length, \
|
||||
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
||||
f'please double-check your config.'
|
||||
for i in range(len(td)):
|
||||
self.buffer.add(td[i])
|
||||
self.buffer.load(td)
|
||||
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
||||
assert self.buffer.num_eps == expected_episodes, \
|
||||
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
||||
if self.buffer.num_eps != expected_episodes:
|
||||
print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
|
||||
|
||||
def train(self):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
|
||||
Reference in New Issue
Block a user