116 lines
3.7 KiB
Python
116 lines
3.7 KiB
Python
import torch
|
|
from tensordict.tensordict import TensorDict
|
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
|
from torchrl.data.replay_buffers.samplers import SliceSampler
|
|
|
|
|
|
class Buffer():
    """
    Replay buffer for TD-MPC2 training. Based on torchrl.
    Uses CUDA memory if available, and CPU memory otherwise.

    Storage is allocated lazily: the underlying torchrl buffer is created on
    the first call to `add` or `load`, using that first episode to estimate
    how much memory the full buffer will need.
    """

    def __init__(self, cfg):
        """
        Set up the slice sampler and bookkeeping state.

        Reads `buffer_size`, `steps`, `batch_size`, `horizon` and `multitask`
        from `cfg`. No storage is allocated here.
        """
        self.cfg = cfg
        # Device that sampled batches are moved to. Falls back to CPU so the
        # buffer stays usable on machines without CUDA.
        self._device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self._capacity = min(cfg.buffer_size, cfg.steps)
        self._sampler = SliceSampler(
            num_slices=self.cfg.batch_size,
            end_key=None,
            traj_key='episode',
            truncated_key=None,
            strict_length=True,
            cache_values=cfg.multitask,
        )
        # Each sampled slice contributes horizon+1 consecutive steps.
        self._batch_size = cfg.batch_size * (cfg.horizon+1)
        self._num_eps = 0

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes in the buffer."""
        return self._num_eps

    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.

        The returned torchrl `ReplayBuffer` uses this object's sampler and
        flat batch size, with memory pinning and prefetching disabled.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=self._sampler,
            pin_memory=False,
            prefetch=0,
            batch_size=self._batch_size,
        )

    @staticmethod
    def _entry_bytes(value):
        """Return the number of bytes held by one entry (tensor or one-level nested TensorDict)."""
        if isinstance(value, TensorDict):
            return sum(x.numel()*x.element_size() for x in value.values())
        return value.numel()*value.element_size()

    @staticmethod
    def _select_storage_device(total_bytes):
        """
        Pick the storage device name ('cuda:0' or 'cpu') for `total_bytes` of data.

        Returns 'cpu' when CUDA is unavailable. Otherwise CUDA is chosen only
        if the currently free CUDA memory exceeds 2.5x the estimated storage.
        """
        if not torch.cuda.is_available():
            return 'cpu'
        mem_free, _ = torch.cuda.mem_get_info()
        return 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        print(f'Buffer capacity: {self._capacity:,}')
        # Average footprint of a single step, extrapolated to full capacity.
        bytes_per_step = sum(self._entry_bytes(v) for v in tds.values()) / len(tds)
        total_bytes = bytes_per_step*self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: decide whether to use CUDA or CPU memory
        storage_device = self._select_storage_device(total_bytes)
        print(f'Using {storage_device.upper()} memory for storage.')
        self._storage_device = torch.device(storage_device)
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=self._storage_device)
        )

    def load(self, td):
        """
        Load a batch of episodes into the buffer. This is useful for loading data from disk,
        and is more efficient than adding episodes one by one.

        `td` is indexed as (episode, step). Every step is tagged with a
        consecutive episode id under the 'episode' key (written into `td`),
        then the first two batch dimensions are flattened before extending
        the buffer. Returns the updated number of episodes.
        """
        num_new_eps = len(td)
        episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
        # Broadcast each episode id across that episode's steps.
        td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
        if self._num_eps == 0:
            self._buffer = self._init(td[0])
        td = td.reshape(td.shape[0]*td.shape[1])
        self._buffer.extend(td)
        self._num_eps += num_new_eps
        return self._num_eps

    def add(self, td):
        """
        Add an episode to the buffer.

        Tags every step of `td` with the current episode index under the
        'episode' key (written into `td`) and returns the updated number of
        episodes. The first call also allocates the underlying storage.
        """
        td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
        if self._num_eps == 0:
            self._buffer = self._init(td)
        self._buffer.extend(td)
        self._num_eps += 1
        return self._num_eps

    def _prepare_batch(self, td):
        """
        Prepare a sampled batch for training (post-processing).
        Expects `td` to be a TensorDict with batch size TxB.

        Returns `(obs, action, reward, terminated, task)`. `obs` keeps all T
        steps; `action`, `reward` and `terminated` drop the first step.
        `terminated` defaults to zeros shaped like `reward` when the key is
        absent, and `task` is taken from the first step or is None when absent.
        """
        td = td.select("obs", "action", "reward", "terminated", "task", strict=False).to(self._device, non_blocking=True)
        obs = td.get('obs').contiguous()
        action = td.get('action')[1:].contiguous()
        reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
        terminated = td.get('terminated', None)
        if terminated is not None:
            terminated = terminated[1:].unsqueeze(-1).contiguous()
        else:
            terminated = torch.zeros_like(reward)
        task = td.get('task', None)
        if task is not None:
            task = task[0].contiguous()
        return obs, action, reward, terminated, task

    def sample(self):
        """
        Sample a batch of subsequences from the buffer.

        The flat sample is viewed as (-1, horizon+1) and permuted so that the
        step dimension comes first before being post-processed.
        """
        td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
        return self._prepare_batch(td)
|