faster offline data loading
This commit is contained in:
@@ -65,6 +65,30 @@ class Buffer():
|
|||||||
LazyTensorStorage(self._capacity, device=self._storage_device)
|
LazyTensorStorage(self._capacity, device=self._storage_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load(self, td):
|
||||||
|
"""
|
||||||
|
Load a batch of episodes into the buffer. This is useful for loading data from disk,
|
||||||
|
and is more efficient than adding episodes one by one.
|
||||||
|
"""
|
||||||
|
num_new_eps = len(td)
|
||||||
|
episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
|
||||||
|
td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
|
||||||
|
if self._num_eps == 0:
|
||||||
|
self._buffer = self._init(td[0])
|
||||||
|
td = td.reshape(td.shape[0]*td.shape[1])
|
||||||
|
self._buffer.extend(td)
|
||||||
|
self._num_eps += num_new_eps
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
|
def add(self, td):
|
||||||
|
"""Add an episode to the buffer."""
|
||||||
|
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
||||||
|
if self._num_eps == 0:
|
||||||
|
self._buffer = self._init(td)
|
||||||
|
self._buffer.extend(td)
|
||||||
|
self._num_eps += 1
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
def _prepare_batch(self, td):
|
def _prepare_batch(self, td):
|
||||||
"""
|
"""
|
||||||
Prepare a sampled batch for training (post-processing).
|
Prepare a sampled batch for training (post-processing).
|
||||||
@@ -79,15 +103,6 @@ class Buffer():
|
|||||||
task = task[0].contiguous()
|
task = task[0].contiguous()
|
||||||
return obs, action, reward, task
|
return obs, action, reward, task
|
||||||
|
|
||||||
def add(self, td):
|
|
||||||
"""Add an episode to the buffer."""
|
|
||||||
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
|
||||||
if self._num_eps == 0:
|
|
||||||
self._buffer = self._init(td)
|
|
||||||
self._buffer.extend(td)
|
|
||||||
self._num_eps += 1
|
|
||||||
return self._num_eps
|
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of subsequences from the buffer."""
|
"""Sample a batch of subsequences from the buffer."""
|
||||||
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class WorldModel(nn.Module):
|
|||||||
eps = torch.randn_like(mean)
|
eps = torch.randn_like(mean)
|
||||||
|
|
||||||
if self.cfg.multitask: # Mask out unused action dimensions
|
if self.cfg.multitask: # Mask out unused action dimensions
|
||||||
mu = mu * self._action_masks[task]
|
mean = mean * self._action_masks[task]
|
||||||
log_std = log_std * self._action_masks[task]
|
log_std = log_std * self._action_masks[task]
|
||||||
eps = eps * self._action_masks[task]
|
eps = eps * self._action_masks[task]
|
||||||
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
|
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ class OfflineTrainer(Trainer):
|
|||||||
fps = sorted(glob(str(fp)))
|
fps = sorted(glob(str(fp)))
|
||||||
assert len(fps) > 0, f'No data found at {fp}'
|
assert len(fps) > 0, f'No data found at {fp}'
|
||||||
print(f'Found {len(fps)} files in {fp}')
|
print(f'Found {len(fps)} files in {fp}')
|
||||||
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
|
if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
|
||||||
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
|
print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
|
||||||
|
|
||||||
# Create buffer for sampling
|
# Create buffer for sampling
|
||||||
_cfg = deepcopy(self.cfg)
|
_cfg = deepcopy(self.cfg)
|
||||||
@@ -59,11 +59,10 @@ class OfflineTrainer(Trainer):
|
|||||||
assert td.shape[1] == _cfg.episode_length, \
|
assert td.shape[1] == _cfg.episode_length, \
|
||||||
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
||||||
f'please double-check your config.'
|
f'please double-check your config.'
|
||||||
for i in range(len(td)):
|
self.buffer.load(td)
|
||||||
self.buffer.add(td[i])
|
|
||||||
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
||||||
assert self.buffer.num_eps == expected_episodes, \
|
if self.buffer.num_eps != expected_episodes:
|
||||||
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Train a TD-MPC2 agent."""
|
||||||
|
|||||||
Reference in New Issue
Block a user