From 9cac7c57759242db83754f0aec2c01fc8dcd340e Mon Sep 17 00:00:00 2001
From: Nicklas Hansen
Date: Thu, 19 Dec 2024 06:52:31 -0800
Subject: [PATCH] faster offline data loading

---
Reviewer notes (free text after the three-dash line; not part of the commit):

 * buffer.py: adds Buffer.load(td). It assigns consecutive episode indices
   starting at self._num_eps, reshapes td to a single batch dimension of
   size td.shape[0]*td.shape[1], and inserts everything with one
   self._buffer.extend() call. Buffer.add() is moved above
   _prepare_batch(); the removed and re-added bodies are identical.
 * world_model.py: the multitask action mask is now applied to `mean`;
   the removed line read and assigned `mu` instead.
 * offline_trainer.py: the file-count and episode-count assertions become
   printed warnings (the file-count warning only fires when fewer files
   than expected are found), and each loaded td is inserted with one
   buffer.load(td) call instead of a per-episode buffer.add() loop.
 * NOTE(review): load() calls self._init(td[0]) while the buffer is empty,
   so it presumably expects td to hold at least one episode - confirm
   against callers.
 * NOTE(review): load() sizes the episode index from td['reward'].shape[1];
   assumes reward is laid out as (episodes, steps) - TODO confirm.
 * NOTE(review): this patch reached review with its line breaks and
   indentation collapsed. The tab indentation and blank context lines below
   are reconstructed from the hunk headers - confirm against the target
   files before applying.

 tdmpc2/common/buffer.py           | 33 ++++++++++++++++++++++---------
 tdmpc2/common/world_model.py      |  2 +-
 tdmpc2/trainer/offline_trainer.py | 11 +++++------
 3 files changed, 30 insertions(+), 16 deletions(-)

diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py
index 3ff5b28..84e49e1 100644
--- a/tdmpc2/common/buffer.py
+++ b/tdmpc2/common/buffer.py
@@ -65,6 +65,30 @@ class Buffer():
 			LazyTensorStorage(self._capacity, device=self._storage_device)
 		)
 
+	def load(self, td):
+		"""
+		Load a batch of episodes into the buffer. This is useful for loading data from disk,
+		and is more efficient than adding episodes one by one.
+		"""
+		num_new_eps = len(td)
+		episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
+		td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
+		if self._num_eps == 0:
+			self._buffer = self._init(td[0])
+		td = td.reshape(td.shape[0]*td.shape[1])
+		self._buffer.extend(td)
+		self._num_eps += num_new_eps
+		return self._num_eps
+
+	def add(self, td):
+		"""Add an episode to the buffer."""
+		td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
+		if self._num_eps == 0:
+			self._buffer = self._init(td)
+		self._buffer.extend(td)
+		self._num_eps += 1
+		return self._num_eps
+
 	def _prepare_batch(self, td):
 		"""
 		Prepare a sampled batch for training (post-processing).
@@ -79,15 +103,6 @@ class Buffer():
 			task = task[0].contiguous()
 		return obs, action, reward, task
 
-	def add(self, td):
-		"""Add an episode to the buffer."""
-		td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
-		if self._num_eps == 0:
-			self._buffer = self._init(td)
-		self._buffer.extend(td)
-		self._num_eps += 1
-		return self._num_eps
-
 	def sample(self):
 		"""Sample a batch of subsequences from the buffer."""
 		td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py
index 8222b99..23b5d40 100644
--- a/tdmpc2/common/world_model.py
+++ b/tdmpc2/common/world_model.py
@@ -137,7 +137,7 @@ class WorldModel(nn.Module):
 		eps = torch.randn_like(mean)
 
 		if self.cfg.multitask: # Mask out unused action dimensions
-			mu = mu * self._action_masks[task]
+			mean = mean * self._action_masks[task]
 			log_std = log_std * self._action_masks[task]
 			eps = eps * self._action_masks[task]
 			action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py
index 0d22fe8..a46d00b 100755
--- a/tdmpc2/trainer/offline_trainer.py
+++ b/tdmpc2/trainer/offline_trainer.py
@@ -45,8 +45,8 @@ class OfflineTrainer(Trainer):
 		fps = sorted(glob(str(fp)))
 		assert len(fps) > 0, f'No data found at {fp}'
 		print(f'Found {len(fps)} files in {fp}')
-		assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
-			f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
+		if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
+			print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
 
 		# Create buffer for sampling
 		_cfg = deepcopy(self.cfg)
@@ -59,11 +59,10 @@ class OfflineTrainer(Trainer):
 			assert td.shape[1] == _cfg.episode_length, \
 				f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
 				f'please double-check your config.'
-			for i in range(len(td)):
-				self.buffer.add(td[i])
+			self.buffer.load(td)
 		expected_episodes = _cfg.buffer_size // _cfg.episode_length
-		assert self.buffer.num_eps == expected_episodes, \
-			f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
+		if self.buffer.num_eps != expected_episodes:
+			print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
 
 	def train(self):
 		"""Train a TD-MPC2 agent."""