98 lines
2.9 KiB
Python
98 lines
2.9 KiB
Python
import torch
|
|
from tensordict.tensordict import TensorDict
|
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
|
|
|
|
|
class Buffer:
    """
    Replay buffer for TD-MPC2 training. Based on torchrl.
    Uses CUDA memory if available, and CPU memory otherwise.

    Episodes are appended with `add` and fixed-length subsequences of
    `cfg.horizon + 1` steps are drawn with `sample`. The underlying torchrl
    buffer is created lazily on the first call to `add`, because the first
    episode is used to estimate the storage requirements.
    """

    def __init__(self, cfg):
        """
        Set up the sampler and bookkeeping; no storage is allocated yet.

        Reads `cfg.buffer_size`, `cfg.steps`, `cfg.batch_size` and
        `cfg.horizon`.
        """
        self.cfg = cfg
        # Fall back to CPU when CUDA is missing, as promised by the class
        # docstring (a hard-coded 'cuda' device breaks CPU-only machines).
        self._device = self._default_device()
        # Never reserve more slots than the total number of training steps.
        self._capacity = min(cfg.buffer_size, cfg.steps)
        self._sampler = SliceSampler(
            num_slices=self.cfg.batch_size,
            end_key=None,
            traj_key='episode',
            truncated_key=None,
            strict_length=True,
        )
        # Each of the `batch_size` slices holds `horizon + 1` steps.
        self._batch_size = cfg.batch_size * (cfg.horizon + 1)
        self._num_eps = 0

    @staticmethod
    def _default_device():
        """Return the CUDA device if CUDA is available, else the CPU device."""
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes in the buffer."""
        return self._num_eps

    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.

        Returns a torchrl `ReplayBuffer` that uses this object's slice
        sampler and batch size.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=self._sampler,
            # Pinned host memory is only useful when batches are copied to a
            # CUDA device afterwards; without CUDA it is assumed to be
            # unsupported, so it is switched off in that case.
            pin_memory=self._device.type == 'cuda',
            prefetch=1,
            batch_size=self._batch_size,
        )

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        print(f'Buffer capacity: {self._capacity:,}')
        # Size of every entry of the episode; one level of nested TensorDicts
        # is expanded into its leaf tensors.
        episode_bytes = sum(
            sum(x.numel() * x.element_size() for x in v.values())
            if isinstance(v, TensorDict)
            else v.numel() * v.element_size()
            for v in tds.values()
        )
        bytes_per_step = episode_bytes / len(tds)
        total_bytes = bytes_per_step * self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: decide whether to use CUDA or CPU memory
        if torch.cuda.is_available():
            mem_free, _ = torch.cuda.mem_get_info()
            # Keep a 2.5x safety margin so the buffer does not crowd out
            # everything else that lives on the GPU.
            storage_device = 'cuda' if 2.5 * total_bytes < mem_free else 'cpu'
        else:
            # Free CUDA memory cannot be queried without CUDA.
            storage_device = 'cpu'
        print(f'Using {storage_device.upper()} memory for storage.')
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=torch.device(storage_device))
        )

    def _to_device(self, *args, device=None):
        """
        Move every argument to `device` (defaults to the buffer's device).

        `None` arguments are passed through unchanged. The result is a lazy
        generator: each transfer is issued when the corresponding item is
        consumed (e.g. by tuple unpacking), and it can be iterated only once.
        """
        if device is None:
            device = self._device
        return (
            arg.to(device, non_blocking=True) if arg is not None else None
            for arg in args
        )

    def _prepare_batch(self, td):
        """
        Prepare a sampled batch for training (post-processing).
        Expects `td` to be a TensorDict with batch size TxB.

        Yields, in order: `obs` (all T steps), `action` and `reward` (first
        step dropped, reward with a trailing singleton dim), `terminated`
        (last step only, with a trailing singleton dim) and `task` (first
        step, or `None` when `td` has no 'task' entry), all moved to the
        buffer's device.
        """
        obs = td['obs']
        action = td['action'][1:]
        reward = td['reward'][1:].unsqueeze(-1)
        terminated = td['terminated'][-1].unsqueeze(-1)
        task = td['task'][0] if 'task' in td.keys() else None
        return self._to_device(obs, action, reward, terminated, task)

    def add(self, td):
        """
        Add an episode to the buffer.

        Writes an 'episode' entry into `td` (in place) holding the running
        episode index, creates the underlying buffer on the first call, and
        returns the number of episodes added so far.
        """
        # Tag every step with its episode index; the sampler uses this key
        # (`traj_key='episode'`) to tell trajectories apart.
        td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
        if self._num_eps == 0:
            self._buffer = self._init(td)
        self._buffer.extend(td)
        self._num_eps += 1
        return self._num_eps

    def sample(self):
        """
        Sample a batch of subsequences from the buffer.

        Must be called after at least one `add`, since the underlying buffer
        is only created there.
        """
        # The flat sample of batch_size * (horizon + 1) steps is reshaped to
        # (num_slices, horizon + 1) and transposed to time-major layout.
        td = self._buffer.sample().view(-1, self.cfg.horizon + 1).permute(1, 0)
        return self._prepare_batch(td)
|