diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 8d914a7..1b9330c 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -1,24 +1,27 @@ -from pathlib import Path import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage +from torchrl.data.replay_buffers.samplers import RandomSampler +from torchrl.envs import RandomCropTensorDict -from common.logger import make_dir from common.samplers import SliceSampler class Buffer(): """ - Create a replay buffer for TD-MPC2 training. + Base class for TD-MPC2 replay buffers. Uses CUDA memory if available, and CPU memory otherwise. """ def __init__(self, cfg): self.cfg = cfg self._device = torch.device('cuda') - self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) - self._capacity = min(cfg.buffer_size, cfg.steps) + self._capacity = None + self._max_eps = None self._num_eps = 0 + self._sampler = None + self._transform = None + self._batch_size = None @property def capacity(self): @@ -36,15 +39,11 @@ class Buffer(): """ return ReplayBuffer( storage=storage, - sampler=SliceSampler( - num_slices=self.cfg.batch_size, - end_key=None, - traj_key='episode', - truncated_key=None, - ), + sampler=self._sampler, pin_memory=True, prefetch=1, - batch_size=self.cfg.batch_size, + transform=self._transform, + batch_size=self._batch_size, ) def _init(self, tds): @@ -53,45 +52,98 @@ class Buffer(): bytes_per_ep = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ else sum([x.numel()*x.element_size() for x in v.values()])) \ - for v in tds.values() - ]) + for k,v in tds.items() + ]) print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep * (self._capacity // self.cfg.episode_length) + total_bytes = bytes_per_ep*self._max_eps print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory - if 2.5*total_bytes > mem_free: # Insufficient CUDA memory - print('Using CPU memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cpu')) - ) - else: # Sufficient CUDA memory - print('Using CUDA memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cuda')) - ) + storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' + print(f'Using {storage_device.upper()} memory for storage.') + return self._reserve_buffer( + LazyTensorStorage(self._capacity, device=torch.device(storage_device)) + ) - def add(self, tds): - """Add a step to the buffer.""" - tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps - tds['step'] = torch.arange(0, len(tds)) + def _to_device(self, *args, device=None): + if device is None: + device = self._device + return (arg.to(device, non_blocking=True) \ + if arg is not None else None for arg in args) + + def _prepare_batch(self, td): + """ + Prepare a sampled batch for training (post-processing). + Expects `td` to be a TensorDict with batch size TxB. + """ + obs = td['obs'] + action = td['action'][1:] + reward = td['reward'][1:].unsqueeze(-1) + task = td['task'][0] if 'task' in td.keys() else None + return self._to_device(obs, action, reward, task) + + def add(self, td): + """Add an episode to the buffer.""" + pass + + def sample(self): + """Sample a batch of sub-trajectories from the buffer.""" + pass + + +class CropBuffer(Buffer): + """ + A replay buffer that first samples trajectories, and then crops to desired length. + """ + + def __init__(self, cfg): + super().__init__(cfg) + self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length + self._max_eps = self._capacity + self._sampler = RandomSampler() + self._transform = RandomCropTensorDict(cfg.horizon+1, -1) + self._batch_size = cfg.batch_size + + def add(self, td): + """Add an episode to the buffer. All episodes are expected to be equal length.""" if self._num_eps == 0: - self._buffer = self._init(tds) - self._buffer.extend(tds) + self._buffer = self._init(td) + self._buffer.add(td) self._num_eps += 1 return self._num_eps def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - td = self._buffer.sample(batch_size=self._batch_size) \ - .view(-1, self.cfg.horizon+1).permute(1, 0) - obs = td['obs'].to(self._device, non_blocking=True) - action = td['action'][1:].to(self._device, non_blocking=True) - reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True) - task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None - return obs, action, reward, task - - def save(self): - """Save the buffer to disk. Useful for storing offline datasets.""" - td = self._buffer._storage._storage.cpu() - fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' - torch.save(td, fp) + """Sample a batch of subsequences from the buffer.""" + td = self._buffer.sample().permute(1,0) + return self._prepare_batch(td) + + +class SliceBuffer(Buffer): + """ + A replay buffer that directly samples subsequences. More efficient than CropBuffer. + """ + + def __init__(self, cfg): + super().__init__(cfg) + self._capacity = min(cfg.buffer_size, cfg.steps) + self._max_eps = self._capacity//cfg.episode_length + self._sampler = SliceSampler( + num_slices=self.cfg.batch_size, + end_key=None, + traj_key='episode', + truncated_key=None, + ) + self._batch_size = cfg.batch_size * (cfg.horizon+1) + + def add(self, td): + """Add an episode to the buffer. Supports variable episode lengths.""" + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + if self._num_eps == 0: + self._buffer = self._init(td) + self._buffer.extend(td) + self._num_eps += 1 + return self._num_eps + + def sample(self): + """Sample a batch of subsequences from the buffer.""" + td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0) + return self._prepare_batch(td) diff --git a/tdmpc2/common/legacy_buffer.py b/tdmpc2/common/legacy_buffer.py deleted file mode 100644 index dbbfea6..0000000 --- a/tdmpc2/common/legacy_buffer.py +++ /dev/null @@ -1,115 +0,0 @@ -from pathlib import Path -import torch -from tensordict.tensordict import TensorDict -from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.envs import RandomCropTensorDict, Transform, Compose - -from common.logger import make_dir - - -class DataPrepTransform(Transform): - """ - Preprocesses data for TD-MPC2 training. - Replay data is expected to be a TensorDict with the following keys: - obs: observations - action: actions - reward: rewards - task: task IDs (optional) - A TensorDict with T time steps has T+1 observations and T actions and rewards. - The first actions and rewards in each TensorDict are dummies and should be ignored. - """ - - def __init__(self): - super().__init__([]) - - def forward(self, td): - td = td.permute(1,0) - return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None) - - -class Buffer(): - """ - Create a replay buffer for TD-MPC2 training. - Uses CUDA memory if available, and CPU memory otherwise. - """ - - def __init__(self, cfg): - self.cfg = cfg - self._device = torch.device('cuda') - self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length - self._num_eps = 0 - - @property - def capacity(self): - """Return the capacity of the buffer.""" - return self._capacity - - @property - def num_eps(self): - """Return the number of episodes in the buffer.""" - return self._num_eps - - def _reserve_buffer(self, storage): - """ - Reserve a buffer with the given storage. - Uses the RandomSampler to sample trajectories, - and the RandomCropTensorDict transform to crop trajectories to the desired length. - DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates. - """ - return ReplayBuffer( - storage=storage, - sampler=RandomSampler(), - pin_memory=True, - prefetch=1, - transform=Compose( - RandomCropTensorDict(self.cfg.horizon+1, -1), - DataPrepTransform(), - ), - batch_size=self.cfg.batch_size, - ) - - def _init(self, tds): - """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" - mem_free, _ = torch.cuda.mem_get_info() - bytes_per_ep = sum([ - (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ - else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) - print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._capacity - print(f'Storage required: {total_bytes/1e9:.2f} GB') - # Heuristic: decide whether to use CUDA or CPU memory - if 2.5*total_bytes > mem_free: # Insufficient CUDA memory - print('Using CPU memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cpu')) - ) - else: # Sufficient CUDA memory - print('Using CUDA memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cuda')) - ) - - def add(self, tds): - """Add an episode to the buffer. All episodes are expected to have the same length.""" - if self._num_eps == 0: - self._buffer = self._init(tds) - self._buffer.add(tds) - self._num_eps += 1 - return self._num_eps - - def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size) - return obs.to(self._device, non_blocking=True), \ - action.to(self._device, non_blocking=True), \ - reward.to(self._device, non_blocking=True), \ - task.to(self._device, non_blocking=True) if task is not None else None - - def save(self): - """Save the buffer to disk. Useful for storing offline datasets.""" - td = self._buffer._storage._storage.cpu() - fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' - torch.save(td, fp) diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py new file mode 100644 index 0000000..7a58d08 --- /dev/null +++ b/tdmpc2/test_buffer.py @@ -0,0 +1,64 @@ +import os +os.environ['LAZY_LEGACY_OP'] = '0' + +import torch +import hydra +from tensordict.tensordict import TensorDict + +from common.buffer import CropBuffer, SliceBuffer + + +@hydra.main(config_name='config', config_path='.') +def test_buffer(cfg: dict): + cfg.episode_length = 12 + cfg.batch_size = 8 + + transitions0 = [TensorDict(dict( + obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, + action=torch.tensor([-1.]) ** t, + reward=torch.tensor([1.]) * t, + ), batch_size=(1,)) for t in range(cfg.episode_length)] + episode0 = torch.cat(transitions0) + + transitions1 = [TensorDict(dict( + obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, + action=(torch.tensor([-1.]) ** t) * 0.5, + reward=torch.tensor([-1.]) * t, + ), batch_size=(1,)) for t in range(cfg.episode_length)] + episode1 = torch.cat(transitions1) + + crop_buffer = CropBuffer(cfg) + slice_buffer = SliceBuffer(cfg) + + crop_buffer.add(episode0) + slice_buffer.add(episode0) + + crop_buffer.add(episode1) + slice_buffer.add(episode1) + + crop_obs, crop_action, crop_reward, _ = crop_buffer.sample() + slice_obs, slice_action, slice_reward, _ = slice_buffer.sample() + + assert crop_obs.shape == slice_obs.shape + assert crop_action.shape == slice_action.shape + assert crop_reward.shape == slice_reward.shape + + assert (crop_obs[1:] - crop_obs[:-1] == 1.).all() + assert (slice_obs[1:] - slice_obs[:-1] == 1.).all() + assert (crop_action.mean().abs() < 0.2) + assert (slice_action.mean().abs() < 0.2) + + crop_rewards, slice_rewards = [], [] + for _ in range(100_000): + _, _, crop_reward_, _ = crop_buffer.sample() + _, _, slice_reward_, _ = slice_buffer.sample() + crop_rewards.append(crop_reward_.mean()) + slice_rewards.append(slice_reward_.mean()) + + crop_rewards = torch.tensor(crop_rewards).mean() + slice_rewards = torch.tensor(slice_rewards).mean() + assert (crop_rewards - slice_rewards) < 0.1 + + +if __name__ == '__main__': + test_buffer() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index ded21e3..5303e09 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -10,8 +10,7 @@ from termcolor import colored from common.parser import parse_cfg from common.seed import set_seed -from common.buffer import Buffer -# from common.legacy_buffer import Buffer +from common.buffer import CropBuffer, SliceBuffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer @@ -52,7 +51,8 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=Buffer(cfg), + buffer=CropBuffer(cfg), + # buffer=SliceBuffer(cfg), logger=Logger(cfg), ) trainer.train()