faster replay buffer implementation

This commit is contained in:
Nicklas Hansen
2023-12-22 05:55:43 -08:00
parent 445af9d81d
commit 3ded0ebc83
5 changed files with 428 additions and 68 deletions

View File

@@ -2,30 +2,10 @@ from pathlib import Path
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir from common.logger import make_dir
from common.samplers import SliceSampler
class DataPrepTransform(Transform):
"""
Preprocesses data for TD-MPC2 training.
Replay data is expected to be a TensorDict with the following keys:
obs: observations
action: actions
reward: rewards
task: task IDs (optional)
A TensorDict with T time steps has T+1 observations and T actions and rewards.
The first actions and rewards in each TensorDict are dummies and should be ignored.
"""
def __init__(self):
super().__init__([])
def forward(self, td):
td = td.permute(1,0)
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
class Buffer(): class Buffer():
@@ -37,7 +17,9 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device('cuda')
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
self._capacity = min(cfg.buffer_size, cfg.steps)
self._num_steps = 0
self._num_eps = 0 self._num_eps = 0
@property @property
@@ -45,6 +27,11 @@ class Buffer():
"""Return the capacity of the buffer.""" """Return the capacity of the buffer."""
return self._capacity return self._capacity
@property
def num_steps(self):
"""Return the number of steps in the buffer."""
return self._num_steps
@property @property
def num_eps(self): def num_eps(self):
"""Return the number of episodes in the buffer.""" """Return the number of episodes in the buffer."""
@@ -53,32 +40,25 @@ class Buffer():
def _reserve_buffer(self, storage): def _reserve_buffer(self, storage):
""" """
Reserve a buffer with the given storage. Reserve a buffer with the given storage.
Uses the RandomSampler to sample trajectories,
and the RandomCropTensorDict transform to crop trajectories to the desired length.
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
""" """
return ReplayBuffer( return ReplayBuffer(
storage=storage, storage=storage,
sampler=RandomSampler(), sampler=SliceSampler(
pin_memory=True, slice_len=self.cfg.horizon+1,
prefetch=1, end_key='done',
transform=Compose( truncated_key=None,
RandomCropTensorDict(self.cfg.horizon+1, -1),
DataPrepTransform(),
), ),
pin_memory=True,
prefetch=2,
batch_size=self.cfg.batch_size, batch_size=self.cfg.batch_size,
) )
def _init(self, tds): def _init(self, td):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements.""" """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
mem_free, _ = torch.cuda.mem_get_info() mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([ bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()])
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \ print(f'Bytes per step: {bytes_per_step:,}')
else sum([x.numel()*x.element_size() for x in v.values()])) \ total_bytes = bytes_per_step*self._capacity
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
@@ -92,21 +72,28 @@ class Buffer():
LazyTensorStorage(self._capacity, device=torch.device('cuda')) LazyTensorStorage(self._capacity, device=torch.device('cuda'))
) )
def add(self, tds): def add(self, td):
"""Add an episode to the buffer. All episodes are expected to have the same length.""" """Add a step to the buffer."""
if self._num_eps == 0: done = bool(td['done'].any())
self._buffer = self._init(tds) if done:
self._buffer.add(tds)
self._num_eps +=1 self._num_eps +=1
return self._num_eps td['episode'] = torch.ones_like(td['done']) * self._num_eps
td['step'] = torch.arange(0, len(td))
if self._num_steps == 0:
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_steps += 1
return self._num_steps
def sample(self): def sample(self):
"""Sample a batch of sub-trajectories from the buffer.""" """Sample a batch of sub-trajectories from the buffer."""
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size) td = self._buffer.sample(batch_size=self._batch_size) \
return obs.to(self._device, non_blocking=True), \ .reshape(-1, self.cfg.horizon+1).permute(1, 0)
action.to(self._device, non_blocking=True), \ obs = td['obs'].to(self._device, non_blocking=True)
reward.to(self._device, non_blocking=True), \ action = td['action'][1:].to(self._device, non_blocking=True)
task.to(self._device, non_blocking=True) if task is not None else None reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None
return obs, action, reward, task
def save(self): def save(self):
"""Save the buffer to disk. Useful for storing offline datasets.""" """Save the buffer to disk. Useful for storing offline datasets."""

365
tdmpc2/common/samplers.py Normal file
View File

@@ -0,0 +1,365 @@
from __future__ import annotations
import json
import warnings
from abc import ABC, abstractmethod
from copy import copy, deepcopy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
from tensordict import MemoryMappedTensor
from tensordict.utils import NestedKey
from torchrl._extension import EXTENSION_WARNING
try:
from torchrl._torchrl import (
MinSegmentTreeFp32,
MinSegmentTreeFp64,
SumSegmentTreeFp32,
SumSegmentTreeFp64,
)
except ImportError:
warnings.warn(EXTENSION_WARNING)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
from torchrl.data.replay_buffers.samplers import Sampler
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
class SliceSampler(Sampler):
    """Samples slices of data along the first dimension, given start and stop signals.

    This class samples sub-trajectories with replacement. For a version without
    replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.

    Keyword Args:
        num_slices (int): the number of slices to be sampled. The batch-size
            must be greater or equal to the ``num_slices`` argument. Exclusive
            with ``slice_len``.
        slice_len (int): the length of the slices to be sampled. The batch-size
            must be greater or equal to the ``slice_len`` argument and divisible
            by it. Exclusive with ``num_slices``.
        end_key (NestedKey, optional): the key indicating the end of a
            trajectory (or episode). Defaults to ``("next", "done")``.
        traj_key (NestedKey, optional): the key indicating the trajectories.
            Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
        cache_values (bool, optional): to be used with static datasets.
            Will cache the start and end signal of the trajectory. When
            ``False``, the trajectory boundaries are recomputed from the
            storage on every call to :meth:`sample`.
        truncated_key (NestedKey, optional): If not ``None``, this argument
            indicates where a truncated signal should be written in the output
            data. This is used to indicate to value estimators where the provided
            trajectory breaks. Defaults to ``("next", "truncated")``.
            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
            instances (otherwise the truncated key is returned in the info dictionary
            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch.
            Be mindful that this can result in effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.

    .. note:: To recover the trajectory splits in the storage,
        :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
        attempt to find the ``traj_key`` entry in the storage. If it cannot be
        found, the ``end_key`` will be used to reconstruct the episodes.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
        >>> from torchrl.data.replay_buffers.samplers import SliceSampler
        >>> torch.manual_seed(0)
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(1_000_000),
        ...     sampler=SliceSampler(cache_values=True, num_slices=10),
        ...     batch_size=320,
        ... )
        >>> episode = torch.zeros(1000, dtype=torch.int)
        >>> episode[:300] = 1
        >>> episode[300:550] = 2
        >>> episode[550:700] = 3
        >>> episode[700:] = 4
        >>> data = TensorDict(
        ...     {
        ...         "episode": episode,
        ...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
        ...         "act": torch.randn((20,)).expand(1000, 20),
        ...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
        ...     }, [1000]
        ... )
        >>> rb.extend(data)
        >>> sample = rb.sample()
        >>> print("sample:", sample)
        >>> print("episodes", sample.get("episode").unique())
        episodes tensor([1, 2, 3, 4], dtype=torch.int32)

    :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
    most of TorchRL's datasets:

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data.datasets import RobosetExperienceReplay
        >>> from torchrl.data import SliceSampler
        >>>
        >>> torch.manual_seed(0)
        >>> num_slices = 10
        >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
        >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
        >>> for batch in data:
        ...     batch = batch.reshape(num_slices, -1)
        ...     break
        >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
        check that each batch only has one episode: tensor([[19],
                [14],
                [ 8],
                [10],
                [13],
                [ 4],
                [ 2],
                [ 3],
                [22],
                [ 8]])

    """

    def __init__(
        self,
        *,
        num_slices: int = None,
        slice_len: int = None,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        cache_values: bool = False,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
    ) -> None:
        """Configure the sampler.

        Raises:
            TypeError: if both or neither of ``num_slices`` and ``slice_len``
                are provided.
        """
        if end_key is None:
            end_key = ("next", "done")
        if traj_key is None:
            traj_key = "episode"
        # Exactly one of the two sizing arguments must be given.
        if not ((num_slices is None) ^ (slice_len is None)):
            raise TypeError(
                "Either num_slices or slice_len must be not None, and not both. "
                f"Got num_slices={num_slices} and slice_len={slice_len}."
            )
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.end_key = end_key
        self.traj_key = traj_key
        self.truncated_key = truncated_key
        self.cache_values = cache_values
        # Try the trajectory-id key first; fall back on the end key if missing.
        self._fetch_traj = True
        self._uses_data_prefix = False
        self.strict_length = strict_length
        self._cache = {}

    @staticmethod
    def _find_start_stop_traj(*, trajectory=None, end=None):
        """Compute trajectory boundaries from trajectory ids or end flags.

        Args:
            trajectory: optional 1-d tensor of trajectory ids; a boundary is
                placed wherever two consecutive ids differ.
            end: optional 1-d tensor of end-of-trajectory flags, used only
                when ``trajectory`` is ``None``.

        Returns:
            A ``(start_idx, stop_idx, lengths)`` tuple of 1-d tensors with one
            entry per trajectory; ``stop_idx`` is inclusive.

        Raises:
            RuntimeError: if the end signal is not 1-dimensional.
        """
        if trajectory is not None:
            # Comparing shifted views was found faster than
            # unique_consecutive or a conv1d-based detection.
            end = trajectory[:-1] != trajectory[1:]
            # The last stored element always closes a trajectory.
            end = torch.cat([end, torch.ones_like(end[:1])], 0)
        else:
            # Force the last stored element to close a trajectory.
            end = torch.index_fill(
                end,
                index=torch.tensor(-1, device=end.device, dtype=torch.long),
                dim=0,
                value=1,
            )
        if end.ndim != 1:
            raise RuntimeError(
                f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
            )
        stop_idx = end.view(-1).nonzero().view(-1)
        start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
        lengths = stop_idx - start_idx + 1
        return start_idx, stop_idx, lengths

    def _tensor_slices_from_startend(self, seq_length, start):
        """Expand slice starts into a flat tensor of storage indices.

        ``seq_length`` is either an int (all slices share one length) or a
        tensor holding one length per entry of ``start``.
        """
        if isinstance(seq_length, int):
            return (
                torch.arange(
                    seq_length, device=start.device, dtype=start.dtype
                ).unsqueeze(0)
                + start.unsqueeze(1)
            ).view(-1)
        else:
            # when padding is needed
            return torch.cat(
                [
                    _start
                    + torch.arange(_seq_len, device=start.device, dtype=start.dtype)
                    for _start, _seq_len in zip(start, seq_length)
                ]
            )

    def _cache_stop_and_length(self, vals):
        """Store ``vals`` in the cache if caching is enabled, then return them."""
        # Only cache when asked to: an unconditional ``setdefault`` would keep
        # returning the first boundaries ever computed, so data written to the
        # storage afterwards could never be sampled.
        if self.cache_values:
            self._cache["stop-and-length"] = vals
        return vals

    def _get_stop_and_length(self, storage, fallback=True):
        """Return ``(start_idx, stop_idx, lengths)`` for the filled part of ``storage``.

        The trajectory key is tried first; if it is missing from the storage
        and ``fallback`` is true, the end key is tried instead (and vice
        versa). The ``KeyError`` is re-raised when neither key can be found.
        """
        if self.cache_values and "stop-and-length" in self._cache:
            return self._cache.get("stop-and-length")

        if self._fetch_traj:
            # We first try with the traj_key
            try:
                # In some cases, the storage hides the data behind "_data".
                # In the future, this may be deprecated, and we don't want to mess
                # with the keys provided by the user so we fall back on a proxy to
                # the traj key.
                try:
                    trajectory = storage._storage.get(self._used_traj_key)
                except KeyError:
                    trajectory = storage._storage.get(("_data", self.traj_key))
                    # cache that value for future use
                    self._used_traj_key = ("_data", self.traj_key)
                self._uses_data_prefix = (
                    isinstance(self._used_traj_key, tuple)
                    and self._used_traj_key[0] == "_data"
                )
                vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
                return self._cache_stop_and_length(vals)
            except KeyError:
                if fallback:
                    self._fetch_traj = False
                    return self._get_stop_and_length(storage, fallback=False)
                raise
        else:
            try:
                # Same "_data" prefix handling as above, for the end key.
                try:
                    done = storage._storage.get(self._used_end_key)
                except KeyError:
                    done = storage._storage.get(("_data", self.end_key))
                    # cache that value for future use
                    self._used_end_key = ("_data", self.end_key)
                self._uses_data_prefix = (
                    isinstance(self._used_end_key, tuple)
                    and self._used_end_key[0] == "_data"
                )
                # Restrict the end signal to the filled part of the storage
                # *before* looking for boundaries; slicing the returned tuple
                # instead would leave unfilled slots in the computation.
                vals = self._find_start_stop_traj(
                    end=done.squeeze()[: len(storage)]
                )
                return self._cache_stop_and_length(vals)
            except KeyError:
                if fallback:
                    self._fetch_traj = True
                    return self._get_stop_and_length(storage, fallback=False)
                raise

    def _adjusted_batch_size(self, batch_size):
        """Split ``batch_size`` into ``(seq_length, num_slices)``.

        Raises:
            RuntimeError: if ``batch_size`` is not divisible by the configured
                ``num_slices`` (or ``slice_len``).
        """
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
                )
            seq_length = batch_size // self.num_slices
            num_slices = self.num_slices
        else:
            if batch_size % self.slice_len != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
                )
            seq_length = self.slice_len
            num_slices = batch_size // self.slice_len
        return seq_length, num_slices

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        """Sample storage indices forming ``batch_size`` elements worth of slices.

        Returns:
            A tuple of the flat index tensor and an info dict (holding the
            truncated signal when ``truncated_key`` is not ``None``).

        Raises:
            RuntimeError: if ``storage`` is not a ``TensorStorage``.
        """
        if not isinstance(storage, TensorStorage):
            raise RuntimeError(
                f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
            )

        # pick up as many trajs as we need
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        seq_length, num_slices = self._adjusted_batch_size(batch_size)
        return self._sample_slices(lengths, start_idx, seq_length, num_slices)

    def _sample_slices(
        self, lengths, start_idx, seq_length, num_slices, traj_idx=None
    ) -> Tuple[torch.Tensor, dict]:
        """Draw random slices of ``seq_length`` steps from random trajectories.

        Args:
            lengths: length of every stored trajectory.
            start_idx: first storage index of every stored trajectory.
            seq_length: requested slice length.
            num_slices: number of slices to draw (ignored if ``traj_idx`` is given).
            traj_idx: optional pre-selected trajectory index per slice.

        Raises:
            RuntimeError: if some trajectory is shorter than ``seq_length``
                and ``strict_length`` is true.
        """
        has_short_traj = bool((lengths < seq_length).any())
        if has_short_traj and self.strict_length:
            raise RuntimeError(
                "Some stored trajectories have a length shorter than the slice that was asked for. "
                "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
                "in you batch."
            )
        if traj_idx is None:
            traj_idx = torch.randint(
                lengths.shape[0], (num_slices,), device=lengths.device
            )
        else:
            num_slices = traj_idx.shape[0]
        if has_short_traj:
            # Make seq_length a tensor with one entry per *sampled* slice,
            # clamped by the length of the trajectory that slice comes from.
            # Clamping the per-trajectory ``lengths`` directly would not line
            # up with ``lengths[traj_idx]`` and ``starts`` below.
            seq_length = lengths[traj_idx].clamp_max(seq_length)

        relative_starts = (
            (
                torch.rand(num_slices, device=lengths.device)
                * (lengths[traj_idx] - seq_length)
            )
            .floor()
            .to(start_idx.dtype)
        )
        starts = start_idx[traj_idx] + relative_starts
        index = self._tensor_slices_from_startend(seq_length, starts)
        if self.truncated_key is not None:
            truncated_key = self.truncated_key
            # Flag the last element of every slice as truncated.
            truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
            if isinstance(seq_length, int):
                truncated.view(num_slices, -1)[:, -1] = 1
            else:
                truncated[seq_length.cumsum(0) - 1] = 1
            return index.to(torch.long), {truncated_key: truncated}
        return index.to(torch.long), {}

    @property
    def _used_traj_key(self):
        """Trajectory key actually found in the storage (defaults to ``traj_key``)."""
        return self.__dict__.get("__used_traj_key", self.traj_key)

    @_used_traj_key.setter
    def _used_traj_key(self, value):
        self.__dict__["__used_traj_key"] = value

    @property
    def _used_end_key(self):
        """End key actually found in the storage (defaults to ``end_key``)."""
        return self.__dict__.get("__used_end_key", self.end_key)

    @_used_end_key.setter
    def _used_end_key(self, value):
        self.__dict__["__used_end_key"] = value

    def _empty(self):
        pass

    def dumps(self, path):
        # no op - cache does not need to be saved
        ...

    def loads(self, path):
        # no op
        ...

    def __getstate__(self):
        """Return the instance state with the cache emptied."""
        state = copy(self.__dict__)
        state["_cache"] = {}
        return state

View File

@@ -60,11 +60,11 @@ dropout: 0.01
simnorm_dim: 8 simnorm_dim: 8
# logging # logging
wandb_project: ??? wandb_project: tdmpcv2
wandb_entity: ??? wandb_entity: nicklashansen
wandb_silent: false wandb_silent: false
disable_wandb: true disable_wandb: false
save_csv: true save_csv: false
# misc # misc
save_video: true save_video: true

View File

@@ -6,9 +6,9 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
from envs.dmcontrol import make_env as make_dm_control_env from envs.dmcontrol import make_env as make_dm_control_env
from envs.maniskill import make_env as make_maniskill_env # from envs.maniskill import make_env as make_maniskill_env
from envs.metaworld import make_env as make_metaworld_env # from envs.metaworld import make_env as make_metaworld_env
from envs.myosuite import make_env as make_myosuite_env # from envs.myosuite import make_env as make_myosuite_env
from envs.exceptions import UnknownTaskError from envs.exceptions import UnknownTaskError
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -44,7 +44,7 @@ def make_env(cfg):
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
except UnknownTaskError: except UnknownTaskError:

View File

@@ -14,6 +14,7 @@ class OnlineTrainer(Trainer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._step = 0 self._step = 0
self._ep_idx = 0 self._ep_idx = 0
self._ep_reward = 0
self._start_time = time() self._start_time = time()
def common_metrics(self): def common_metrics(self):
@@ -21,6 +22,7 @@ class OnlineTrainer(Trainer):
return dict( return dict(
step=self._step, step=self._step,
episode=self._ep_idx, episode=self._ep_idx,
episode_reward=self._ep_reward,
total_time=time() - self._start_time, total_time=time() - self._start_time,
) )
@@ -47,20 +49,24 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
) )
def to_td(self, obs, action=None, reward=None): def to_td(self, obs, action=None, reward=None, done=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu()
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.cpu()
if action is None: if action is None:
action = torch.empty_like(self.env.rand_act()) action = torch.empty_like(self.env.rand_act())
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
if done is None:
done = False
done = torch.tensor(done)
td = TensorDict(dict( td = TensorDict(dict(
obs=obs, obs=obs.unsqueeze(0),
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
done=done.unsqueeze(0),
), batch_size=(1,)) ), batch_size=(1,))
return td return td
@@ -83,15 +89,16 @@ class OnlineTrainer(Trainer):
if self._step > 0: if self._step > 0:
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'], episode_success=info['success'],
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx += 1
self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
self._ep_reward = 0
# Collect experience # Collect experience
if self._step > self.cfg.seed_steps: if self._step > self.cfg.seed_steps:
@@ -99,7 +106,8 @@ class OnlineTrainer(Trainer):
else: else:
action = self.env.rand_act() action = self.env.rand_act()
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward)) self._tds.append(self.to_td(obs, action, reward, done))
self._ep_reward += reward
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps: