unit test buffer implementations

This commit is contained in:
Nicklas Hansen
2023-12-23 07:43:28 -08:00
parent 70fe242adc
commit eef1d1b407
4 changed files with 164 additions and 163 deletions

View File

@@ -1,24 +1,27 @@
from pathlib import Path
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict
from common.logger import make_dir
from common.samplers import SliceSampler from common.samplers import SliceSampler
class Buffer(): class Buffer():
""" """
Create a replay buffer for TD-MPC2 training. Base class for TD-MPC2 replay buffers.
Uses CUDA memory if available, and CPU memory otherwise. Uses CUDA memory if available, and CPU memory otherwise.
""" """
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device('cuda')
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) self._capacity = None
self._capacity = min(cfg.buffer_size, cfg.steps) self._max_eps = None
self._num_eps = 0 self._num_eps = 0
self._sampler = None
self._transform = None
self._batch_size = None
@property @property
def capacity(self): def capacity(self):
@@ -36,15 +39,11 @@ class Buffer():
""" """
return ReplayBuffer( return ReplayBuffer(
storage=storage, storage=storage,
sampler=SliceSampler( sampler=self._sampler,
num_slices=self.cfg.batch_size,
end_key=None,
traj_key='episode',
truncated_key=None,
),
pin_memory=True, pin_memory=True,
prefetch=1, prefetch=1,
batch_size=self.cfg.batch_size, transform=self._transform,
batch_size=self._batch_size,
) )
def _init(self, tds): def _init(self, tds):
@@ -53,45 +52,98 @@ class Buffer():
bytes_per_ep = sum([ bytes_per_ep = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \ else sum([x.numel()*x.element_size() for x in v.values()])) \
for v in tds.values() for k,v in tds.items()
]) ])
print(f'Bytes per episode: {bytes_per_ep:,}') print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep * (self._capacity // self.cfg.episode_length) total_bytes = bytes_per_ep*self._max_eps
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
print('Using CPU memory for storage.') print(f'Using {storage_device.upper()} memory for storage.')
return self._reserve_buffer( return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cpu')) LazyTensorStorage(self._capacity, device=torch.device(storage_device))
) )
else: # Sufficient CUDA memory
print('Using CUDA memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
)
def add(self, tds): def _to_device(self, *args, device=None):
"""Add a step to the buffer.""" if device is None:
tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps device = self._device
tds['step'] = torch.arange(0, len(tds)) return (arg.to(device, non_blocking=True) \
if arg is not None else None for arg in args)
def _prepare_batch(self, td):
"""
Prepare a sampled batch for training (post-processing).
Expects `td` to be a TensorDict with batch size TxB.
"""
obs = td['obs']
action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task)
def add(self, td):
    """
    Add an episode to the buffer.

    Abstract in the base class: concrete buffers must override this.

    Raises:
        NotImplementedError: always, so that an accidental call on the
            base class fails loudly instead of silently dropping `td`.
    """
    raise NotImplementedError('Buffer subclasses must implement add().')
def sample(self):
    """
    Sample a batch of sub-trajectories from the buffer.

    Abstract in the base class: concrete buffers must override this.

    Raises:
        NotImplementedError: always, so that callers do not receive None
            and fail later when unpacking the expected batch tuple.
    """
    raise NotImplementedError('Buffer subclasses must implement sample().')
class CropBuffer(Buffer):
"""
A replay buffer that first samples trajectories, and then crops to desired length.
"""
def __init__(self, cfg):
    """
    Configure an episode-level buffer.

    Capacity is counted in episodes: the step budget
    min(cfg.buffer_size, cfg.steps) divided by cfg.episode_length.
    Sampling uses a RandomSampler, and each sample is cropped by
    RandomCropTensorDict to cfg.horizon+1 steps.
    """
    super().__init__(cfg)
    # Storage holds whole episodes, so capacity and max episodes coincide.
    max_episodes = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
    self._capacity = max_episodes
    self._max_eps = self._capacity
    self._sampler = RandomSampler()
    crop_length = cfg.horizon+1
    self._transform = RandomCropTensorDict(crop_length, -1)
    self._batch_size = cfg.batch_size
def add(self, td):
"""Add an episode to the buffer. All episodes are expected to be equal length."""
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(tds) self._buffer = self._init(td)
self._buffer.extend(tds) self._buffer.add(td)
self._num_eps += 1 self._num_eps += 1
return self._num_eps return self._num_eps
def sample(self): def sample(self):
"""Sample a batch of sub-trajectories from the buffer.""" """Sample a batch of subsequences from the buffer."""
td = self._buffer.sample(batch_size=self._batch_size) \ td = self._buffer.sample().permute(1,0)
.view(-1, self.cfg.horizon+1).permute(1, 0) return self._prepare_batch(td)
obs = td['obs'].to(self._device, non_blocking=True)
action = td['action'][1:].to(self._device, non_blocking=True)
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None
return obs, action, reward, task
def save(self):
"""Save the buffer to disk. Useful for storing offline datasets.""" class SliceBuffer(Buffer):
td = self._buffer._storage._storage.cpu() """
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' A replay buffer that directly samples subsequences. More efficient than CropBuffer.
torch.save(td, fp) """
def __init__(self, cfg):
    """
    Configure a step-level buffer that samples subsequences directly.

    Capacity is counted in steps: min(cfg.buffer_size, cfg.steps).
    The maximum number of episodes is that capacity divided by
    cfg.episode_length. Each sample draws cfg.batch_size slices, so the
    flat batch size is cfg.batch_size * (cfg.horizon+1).
    """
    super().__init__(cfg)
    step_capacity = min(cfg.buffer_size, cfg.steps)
    self._capacity = step_capacity
    self._max_eps = self._capacity//cfg.episode_length
    # Slices are grouped by the 'episode' key.
    sampler_kwargs = dict(
        num_slices=self.cfg.batch_size,
        end_key=None,
        traj_key='episode',
        truncated_key=None,
    )
    self._sampler = SliceSampler(**sampler_kwargs)
    slice_length = cfg.horizon+1
    self._batch_size = cfg.batch_size * slice_length
def add(self, td):
    """
    Add an episode to the buffer. Supports variable episode lengths.

    Tags every step of `td` with an int64 'episode' entry equal to the
    current episode count, creates the underlying buffer from the first
    episode, then appends the steps.

    Returns the number of episodes added so far, including this one.
    """
    episode_id = self._num_eps
    # Same shape as the reward entry, filled with this episode's index.
    td['episode'] = torch.full_like(td['reward'], episode_id, dtype=torch.int64)
    is_first_episode = self._num_eps == 0
    if is_first_episode:
        # Buffer creation is deferred until an episode is available.
        self._buffer = self._init(td)
    self._buffer.extend(td)
    self._num_eps = self._num_eps + 1
    return self._num_eps
def sample(self):
    """
    Sample a batch of subsequences from the buffer.

    The flat sample is reshaped into rows of cfg.horizon+1 consecutive
    steps and then transposed so that the time dimension comes first,
    before being handed to `_prepare_batch`.
    """
    slice_length = self.cfg.horizon+1
    flat = self._buffer.sample()
    per_slice = flat.view(-1, slice_length)
    time_major = per_slice.permute(1, 0)
    return self._prepare_batch(time_major)

View File

@@ -1,115 +0,0 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
class DataPrepTransform(Transform):
    """
    Transform that converts sampled replay data into TD-MPC2 training inputs.

    Replay data is expected to be a TensorDict with the following keys:
        obs: observations
        action: actions
        reward: rewards
        task: task IDs (optional)
    A TensorDict with T time steps has T+1 observations and T actions and rewards.
    The first actions and rewards in each TensorDict are dummies and should be ignored.
    """

    def __init__(self):
        super().__init__([])

    def forward(self, td):
        """
        Swap the first two batch dimensions of `td` and unpack it.

        Returns a tuple (obs, action, reward, task): the first action and
        reward along the leading dimension are dropped, reward gains a
        trailing singleton dimension, and task is the first entry of
        `td['task']` or None when that key is absent.
        """
        swapped = td.permute(1,0)
        obs = swapped['obs']
        action = swapped['action'][1:]
        reward = swapped['reward'][1:].unsqueeze(-1)
        if 'task' in swapped.keys():
            task = swapped['task'][0]
        else:
            task = None
        return obs, action, reward, task
class Buffer():
    """
    Episode-level replay buffer for TD-MPC2 training.

    Storage lives in CUDA memory when 2.5x the estimated storage
    requirement fits in the currently free CUDA memory, and in CPU
    memory otherwise. Sampled batches are always moved to CUDA.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self._device = torch.device('cuda')
        # Capacity is counted in episodes, not steps.
        step_budget = min(cfg.buffer_size, cfg.steps)
        self._capacity = step_budget//cfg.episode_length
        self._num_eps = 0

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes in the buffer."""
        return self._num_eps

    def _reserve_buffer(self, storage):
        """
        Build the ReplayBuffer on top of the given storage.

        Trajectories are drawn with RandomSampler, cropped to
        cfg.horizon+1 steps by RandomCropTensorDict, and then converted
        by DataPrepTransform into the format expected by TD-MPC2 updates.
        """
        pipeline = Compose(
            RandomCropTensorDict(self.cfg.horizon+1, -1),
            DataPrepTransform(),
        )
        return ReplayBuffer(
            storage=storage,
            sampler=RandomSampler(),
            pin_memory=True,
            prefetch=1,
            transform=pipeline,
            batch_size=self.cfg.batch_size,
        )

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        mem_free, _ = torch.cuda.mem_get_info()

        def _entry_bytes(entry):
            # Nested TensorDicts are sized by summing their direct values.
            if isinstance(entry, TensorDict):
                return sum([x.numel()*x.element_size() for x in entry.values()])
            return entry.numel()*entry.element_size()

        bytes_per_ep = sum([_entry_bytes(v) for v in tds.values()])
        print(f'Bytes per episode: {bytes_per_ep:,}')
        total_bytes = bytes_per_ep*self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: decide whether to use CUDA or CPU memory
        if 2.5*total_bytes > mem_free:
            # Insufficient CUDA memory
            storage_device, notice = 'cpu', 'Using CPU memory for storage.'
        else:
            # Sufficient CUDA memory
            storage_device, notice = 'cuda', 'Using CUDA memory for storage.'
        print(notice)
        storage = LazyTensorStorage(self._capacity, device=torch.device(storage_device))
        return self._reserve_buffer(storage)

    def add(self, tds):
        """Add an episode to the buffer. All episodes are expected to have the same length."""
        first_episode = self._num_eps == 0
        if first_episode:
            # The buffer is created lazily from the first episode.
            self._buffer = self._init(tds)
        self._buffer.add(tds)
        self._num_eps = self._num_eps + 1
        return self._num_eps

    def sample(self):
        """Sample a batch of sub-trajectories from the buffer."""
        obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)

        def _move(tensor):
            return tensor.to(self._device, non_blocking=True)

        # Task IDs are optional and stay None when absent.
        moved_task = _move(task) if task is not None else None
        return _move(obs), _move(action), _move(reward), moved_task

    def save(self):
        """Save the buffer to disk. Useful for storing offline datasets."""
        # NOTE(review): reaches into private storage attributes — confirm against the torchrl version in use.
        td = self._buffer._storage._storage.cpu()
        out_dir = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed))
        fp = out_dir / f'{self._num_eps}.pt'
        torch.save(td, fp)

64
tdmpc2/test_buffer.py Normal file
View File

@@ -0,0 +1,64 @@
import os
os.environ['LAZY_LEGACY_OP'] = '0'
import torch
import hydra
from tensordict.tensordict import TensorDict
from common.buffer import CropBuffer, SliceBuffer
@hydra.main(config_name='config', config_path='.')
def test_buffer(cfg: dict):
    """
    Check that CropBuffer and SliceBuffer produce comparable samples.

    Two synthetic episodes with known structure are added to both
    buffers: observations increase by exactly 1 per step, actions
    alternate in sign, and rewards grow linearly with opposite signs in
    the two episodes. The test then verifies that sampled shapes match,
    that sampled observations are temporally contiguous, that actions
    average out near zero, and that the mean sampled reward of the two
    buffers agrees over many draws.
    """
    cfg.episode_length = 12
    cfg.batch_size = 8

    def make_episode(obs_start, action_scale, reward_sign):
        """Build one episode by concatenating single-step TensorDicts."""
        transitions = [TensorDict(dict(
            obs=torch.tensor([obs_start + i for i in range(5)]).unsqueeze(0) + t,
            action=(torch.tensor([-1.]) ** t) * action_scale,
            reward=torch.tensor([reward_sign]) * t,
        ), batch_size=(1,)) for t in range(cfg.episode_length)]
        return torch.cat(transitions)

    episode0 = make_episode(obs_start=0., action_scale=1., reward_sign=1.)
    episode1 = make_episode(obs_start=20., action_scale=0.5, reward_sign=-1.)

    crop_buffer = CropBuffer(cfg)
    slice_buffer = SliceBuffer(cfg)
    for episode in (episode0, episode1):
        crop_buffer.add(episode)
        slice_buffer.add(episode)

    crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
    slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()
    assert crop_obs.shape == slice_obs.shape
    assert crop_action.shape == slice_action.shape
    assert crop_reward.shape == slice_reward.shape

    # Consecutive observations within a sampled subsequence differ by exactly 1,
    # so any sample that crosses an episode boundary would fail here.
    assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
    assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()

    # Actions alternate in sign, so their mean should be close to zero.
    assert (crop_action.mean().abs() < 0.2)
    assert (slice_action.mean().abs() < 0.2)

    crop_rewards, slice_rewards = [], []
    for _ in range(100_000):
        _, _, crop_reward_, _ = crop_buffer.sample()
        _, _, slice_reward_, _ = slice_buffer.sample()
        # Store plain floats so the lists do not hold device tensors.
        crop_rewards.append(crop_reward_.mean().item())
        slice_rewards.append(slice_reward_.mean().item())
    crop_rewards = torch.tensor(crop_rewards).mean()
    slice_rewards = torch.tensor(slice_rewards).mean()
    # Compare magnitudes: a one-sided check would accept any large negative difference.
    assert (crop_rewards - slice_rewards).abs() < 0.1


if __name__ == '__main__':
    test_buffer()

View File

@@ -10,8 +10,7 @@ from termcolor import colored
from common.parser import parse_cfg from common.parser import parse_cfg
from common.seed import set_seed from common.seed import set_seed
from common.buffer import Buffer from common.buffer import CropBuffer, SliceBuffer
# from common.legacy_buffer import Buffer
from envs import make_env from envs import make_env
from tdmpc2 import TDMPC2 from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer
@@ -52,7 +51,8 @@ def train(cfg: dict):
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),
agent=TDMPC2(cfg), agent=TDMPC2(cfg),
buffer=Buffer(cfg), buffer=CropBuffer(cfg),
# buffer=SliceBuffer(cfg),
logger=Logger(cfg), logger=Logger(cfg),
) )
trainer.train() trainer.train()