unit test buffer implementations
This commit is contained in:
@@ -1,24 +1,27 @@
|
|||||||
from pathlib import Path
|
|
||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||||
|
from torchrl.data.replay_buffers.samplers import RandomSampler
|
||||||
|
from torchrl.envs import RandomCropTensorDict
|
||||||
|
|
||||||
from common.logger import make_dir
|
|
||||||
from common.samplers import SliceSampler
|
from common.samplers import SliceSampler
|
||||||
|
|
||||||
|
|
||||||
class Buffer():
|
class Buffer():
|
||||||
"""
|
"""
|
||||||
Create a replay buffer for TD-MPC2 training.
|
Base class for TD-MPC2 replay buffers.
|
||||||
Uses CUDA memory if available, and CPU memory otherwise.
|
Uses CUDA memory if available, and CPU memory otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._device = torch.device('cuda')
|
self._device = torch.device('cuda')
|
||||||
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
|
self._capacity = None
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
self._max_eps = None
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
|
self._sampler = None
|
||||||
|
self._transform = None
|
||||||
|
self._batch_size = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def capacity(self):
|
def capacity(self):
|
||||||
@@ -36,15 +39,11 @@ class Buffer():
|
|||||||
"""
|
"""
|
||||||
return ReplayBuffer(
|
return ReplayBuffer(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=SliceSampler(
|
sampler=self._sampler,
|
||||||
num_slices=self.cfg.batch_size,
|
|
||||||
end_key=None,
|
|
||||||
traj_key='episode',
|
|
||||||
truncated_key=None,
|
|
||||||
),
|
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
prefetch=1,
|
prefetch=1,
|
||||||
batch_size=self.cfg.batch_size,
|
transform=self._transform,
|
||||||
|
batch_size=self._batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init(self, tds):
|
def _init(self, tds):
|
||||||
@@ -53,45 +52,98 @@ class Buffer():
|
|||||||
bytes_per_ep = sum([
|
bytes_per_ep = sum([
|
||||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||||
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
||||||
for v in tds.values()
|
for k,v in tds.items()
|
||||||
])
|
])
|
||||||
print(f'Bytes per episode: {bytes_per_ep:,}')
|
print(f'Bytes per episode: {bytes_per_ep:,}')
|
||||||
total_bytes = bytes_per_ep * (self._capacity // self.cfg.episode_length)
|
total_bytes = bytes_per_ep*self._max_eps
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
||||||
print('Using CPU memory for storage.')
|
print(f'Using {storage_device.upper()} memory for storage.')
|
||||||
return self._reserve_buffer(
|
return self._reserve_buffer(
|
||||||
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
|
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
||||||
)
|
)
|
||||||
else: # Sufficient CUDA memory
|
|
||||||
print('Using CUDA memory for storage.')
|
|
||||||
return self._reserve_buffer(
|
|
||||||
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
|
||||||
)
|
|
||||||
|
|
||||||
def add(self, tds):
|
def _to_device(self, *args, device=None):
|
||||||
"""Add a step to the buffer."""
|
if device is None:
|
||||||
tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps
|
device = self._device
|
||||||
tds['step'] = torch.arange(0, len(tds))
|
return (arg.to(device, non_blocking=True) \
|
||||||
|
if arg is not None else None for arg in args)
|
||||||
|
|
||||||
|
def _prepare_batch(self, td):
|
||||||
|
"""
|
||||||
|
Prepare a sampled batch for training (post-processing).
|
||||||
|
Expects `td` to be a TensorDict with batch size TxB.
|
||||||
|
"""
|
||||||
|
obs = td['obs']
|
||||||
|
action = td['action'][1:]
|
||||||
|
reward = td['reward'][1:].unsqueeze(-1)
|
||||||
|
task = td['task'][0] if 'task' in td.keys() else None
|
||||||
|
return self._to_device(obs, action, reward, task)
|
||||||
|
|
||||||
|
def add(self, td):
|
||||||
|
"""Add an episode to the buffer."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""Sample a batch of sub-trajectories from the buffer."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CropBuffer(Buffer):
|
||||||
|
"""
|
||||||
|
A replay buffer that first samples trajectories, and then crops to desired length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
|
||||||
|
self._max_eps = self._capacity
|
||||||
|
self._sampler = RandomSampler()
|
||||||
|
self._transform = RandomCropTensorDict(cfg.horizon+1, -1)
|
||||||
|
self._batch_size = cfg.batch_size
|
||||||
|
|
||||||
|
def add(self, td):
|
||||||
|
"""Add an episode to the buffer. All episodes are expected to be equal length."""
|
||||||
if self._num_eps == 0:
|
if self._num_eps == 0:
|
||||||
self._buffer = self._init(tds)
|
self._buffer = self._init(td)
|
||||||
self._buffer.extend(tds)
|
self._buffer.add(td)
|
||||||
self._num_eps += 1
|
self._num_eps += 1
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of sub-trajectories from the buffer."""
|
"""Sample a batch of subsequences from the buffer."""
|
||||||
td = self._buffer.sample(batch_size=self._batch_size) \
|
td = self._buffer.sample().permute(1,0)
|
||||||
.view(-1, self.cfg.horizon+1).permute(1, 0)
|
return self._prepare_batch(td)
|
||||||
obs = td['obs'].to(self._device, non_blocking=True)
|
|
||||||
action = td['action'][1:].to(self._device, non_blocking=True)
|
|
||||||
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
|
class SliceBuffer(Buffer):
|
||||||
task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None
|
"""
|
||||||
return obs, action, reward, task
|
A replay buffer that directly samples subsequences. More efficient than CropBuffer.
|
||||||
|
"""
|
||||||
def save(self):
|
|
||||||
"""Save the buffer to disk. Useful for storing offline datasets."""
|
def __init__(self, cfg):
|
||||||
td = self._buffer._storage._storage.cpu()
|
super().__init__(cfg)
|
||||||
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
torch.save(td, fp)
|
self._max_eps = self._capacity//cfg.episode_length
|
||||||
|
self._sampler = SliceSampler(
|
||||||
|
num_slices=self.cfg.batch_size,
|
||||||
|
end_key=None,
|
||||||
|
traj_key='episode',
|
||||||
|
truncated_key=None,
|
||||||
|
)
|
||||||
|
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||||
|
|
||||||
|
def add(self, td):
|
||||||
|
"""Add an episode to the buffer. Supports variable episode lengths."""
|
||||||
|
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
||||||
|
if self._num_eps == 0:
|
||||||
|
self._buffer = self._init(td)
|
||||||
|
self._buffer.extend(td)
|
||||||
|
self._num_eps += 1
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""Sample a batch of subsequences from the buffer."""
|
||||||
|
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
|
return self._prepare_batch(td)
|
||||||
|
|||||||
@@ -1,115 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
import torch
|
|
||||||
from tensordict.tensordict import TensorDict
|
|
||||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
|
||||||
from torchrl.data.replay_buffers.samplers import RandomSampler
|
|
||||||
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
|
||||||
|
|
||||||
from common.logger import make_dir
|
|
||||||
|
|
||||||
|
|
||||||
class DataPrepTransform(Transform):
|
|
||||||
"""
|
|
||||||
Preprocesses data for TD-MPC2 training.
|
|
||||||
Replay data is expected to be a TensorDict with the following keys:
|
|
||||||
obs: observations
|
|
||||||
action: actions
|
|
||||||
reward: rewards
|
|
||||||
task: task IDs (optional)
|
|
||||||
A TensorDict with T time steps has T+1 observations and T actions and rewards.
|
|
||||||
The first actions and rewards in each TensorDict are dummies and should be ignored.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__([])
|
|
||||||
|
|
||||||
def forward(self, td):
|
|
||||||
td = td.permute(1,0)
|
|
||||||
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
|
|
||||||
|
|
||||||
|
|
||||||
class Buffer():
|
|
||||||
"""
|
|
||||||
Create a replay buffer for TD-MPC2 training.
|
|
||||||
Uses CUDA memory if available, and CPU memory otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
self.cfg = cfg
|
|
||||||
self._device = torch.device('cuda')
|
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
|
|
||||||
self._num_eps = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def capacity(self):
|
|
||||||
"""Return the capacity of the buffer."""
|
|
||||||
return self._capacity
|
|
||||||
|
|
||||||
@property
|
|
||||||
def num_eps(self):
|
|
||||||
"""Return the number of episodes in the buffer."""
|
|
||||||
return self._num_eps
|
|
||||||
|
|
||||||
def _reserve_buffer(self, storage):
|
|
||||||
"""
|
|
||||||
Reserve a buffer with the given storage.
|
|
||||||
Uses the RandomSampler to sample trajectories,
|
|
||||||
and the RandomCropTensorDict transform to crop trajectories to the desired length.
|
|
||||||
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
|
|
||||||
"""
|
|
||||||
return ReplayBuffer(
|
|
||||||
storage=storage,
|
|
||||||
sampler=RandomSampler(),
|
|
||||||
pin_memory=True,
|
|
||||||
prefetch=1,
|
|
||||||
transform=Compose(
|
|
||||||
RandomCropTensorDict(self.cfg.horizon+1, -1),
|
|
||||||
DataPrepTransform(),
|
|
||||||
),
|
|
||||||
batch_size=self.cfg.batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init(self, tds):
|
|
||||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
|
||||||
mem_free, _ = torch.cuda.mem_get_info()
|
|
||||||
bytes_per_ep = sum([
|
|
||||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
|
||||||
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
|
||||||
for k,v in tds.items()
|
|
||||||
])
|
|
||||||
print(f'Bytes per episode: {bytes_per_ep:,}')
|
|
||||||
total_bytes = bytes_per_ep*self._capacity
|
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
|
||||||
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
|
||||||
print('Using CPU memory for storage.')
|
|
||||||
return self._reserve_buffer(
|
|
||||||
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
|
|
||||||
)
|
|
||||||
else: # Sufficient CUDA memory
|
|
||||||
print('Using CUDA memory for storage.')
|
|
||||||
return self._reserve_buffer(
|
|
||||||
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
|
||||||
)
|
|
||||||
|
|
||||||
def add(self, tds):
|
|
||||||
"""Add an episode to the buffer. All episodes are expected to have the same length."""
|
|
||||||
if self._num_eps == 0:
|
|
||||||
self._buffer = self._init(tds)
|
|
||||||
self._buffer.add(tds)
|
|
||||||
self._num_eps += 1
|
|
||||||
return self._num_eps
|
|
||||||
|
|
||||||
def sample(self):
|
|
||||||
"""Sample a batch of sub-trajectories from the buffer."""
|
|
||||||
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
|
|
||||||
return obs.to(self._device, non_blocking=True), \
|
|
||||||
action.to(self._device, non_blocking=True), \
|
|
||||||
reward.to(self._device, non_blocking=True), \
|
|
||||||
task.to(self._device, non_blocking=True) if task is not None else None
|
|
||||||
|
|
||||||
def save(self):
|
|
||||||
"""Save the buffer to disk. Useful for storing offline datasets."""
|
|
||||||
td = self._buffer._storage._storage.cpu()
|
|
||||||
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
|
|
||||||
torch.save(td, fp)
|
|
||||||
64
tdmpc2/test_buffer.py
Normal file
64
tdmpc2/test_buffer.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import os
|
||||||
|
os.environ['LAZY_LEGACY_OP'] = '0'
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import hydra
|
||||||
|
from tensordict.tensordict import TensorDict
|
||||||
|
|
||||||
|
from common.buffer import CropBuffer, SliceBuffer
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(config_name='config', config_path='.')
|
||||||
|
def test_buffer(cfg: dict):
|
||||||
|
cfg.episode_length = 12
|
||||||
|
cfg.batch_size = 8
|
||||||
|
|
||||||
|
transitions0 = [TensorDict(dict(
|
||||||
|
obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
|
||||||
|
action=torch.tensor([-1.]) ** t,
|
||||||
|
reward=torch.tensor([1.]) * t,
|
||||||
|
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
||||||
|
episode0 = torch.cat(transitions0)
|
||||||
|
|
||||||
|
transitions1 = [TensorDict(dict(
|
||||||
|
obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
|
||||||
|
action=(torch.tensor([-1.]) ** t) * 0.5,
|
||||||
|
reward=torch.tensor([-1.]) * t,
|
||||||
|
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
||||||
|
episode1 = torch.cat(transitions1)
|
||||||
|
|
||||||
|
crop_buffer = CropBuffer(cfg)
|
||||||
|
slice_buffer = SliceBuffer(cfg)
|
||||||
|
|
||||||
|
crop_buffer.add(episode0)
|
||||||
|
slice_buffer.add(episode0)
|
||||||
|
|
||||||
|
crop_buffer.add(episode1)
|
||||||
|
slice_buffer.add(episode1)
|
||||||
|
|
||||||
|
crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
|
||||||
|
slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()
|
||||||
|
|
||||||
|
assert crop_obs.shape == slice_obs.shape
|
||||||
|
assert crop_action.shape == slice_action.shape
|
||||||
|
assert crop_reward.shape == slice_reward.shape
|
||||||
|
|
||||||
|
assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
|
||||||
|
assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()
|
||||||
|
assert (crop_action.mean().abs() < 0.2)
|
||||||
|
assert (slice_action.mean().abs() < 0.2)
|
||||||
|
|
||||||
|
crop_rewards, slice_rewards = [], []
|
||||||
|
for _ in range(100_000):
|
||||||
|
_, _, crop_reward_, _ = crop_buffer.sample()
|
||||||
|
_, _, slice_reward_, _ = slice_buffer.sample()
|
||||||
|
crop_rewards.append(crop_reward_.mean())
|
||||||
|
slice_rewards.append(slice_reward_.mean())
|
||||||
|
|
||||||
|
crop_rewards = torch.tensor(crop_rewards).mean()
|
||||||
|
slice_rewards = torch.tensor(slice_rewards).mean()
|
||||||
|
assert (crop_rewards - slice_rewards) < 0.1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_buffer()
|
||||||
@@ -10,8 +10,7 @@ from termcolor import colored
|
|||||||
|
|
||||||
from common.parser import parse_cfg
|
from common.parser import parse_cfg
|
||||||
from common.seed import set_seed
|
from common.seed import set_seed
|
||||||
from common.buffer import Buffer
|
from common.buffer import CropBuffer, SliceBuffer
|
||||||
# from common.legacy_buffer import Buffer
|
|
||||||
from envs import make_env
|
from envs import make_env
|
||||||
from tdmpc2 import TDMPC2
|
from tdmpc2 import TDMPC2
|
||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
@@ -52,7 +51,8 @@ def train(cfg: dict):
|
|||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
agent=TDMPC2(cfg),
|
agent=TDMPC2(cfg),
|
||||||
buffer=Buffer(cfg),
|
buffer=CropBuffer(cfg),
|
||||||
|
# buffer=SliceBuffer(cfg),
|
||||||
logger=Logger(cfg),
|
logger=Logger(cfg),
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user