migrate to slicebuffer from torchrl-nightly

This commit is contained in:
Nicklas Hansen
2024-01-10 19:53:30 -08:00
parent ff02f41e73
commit ab43880945
2 changed files with 1 additions and 353 deletions

View File

@@ -1,8 +1,7 @@
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from common.samplers import SliceSampler
class Buffer(): class Buffer():

View File

@@ -1,351 +0,0 @@
from __future__ import annotations
from copy import copy
from typing import Tuple
import torch
from tensordict.utils import NestedKey
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
slice_len (int): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``("next", "done")``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
trajectory breaks. Defaults to ``("next", "truncated")``.
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
instances (otherwise the truncated key is returned in the info dictionary
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
... storage=LazyMemmapStorage(1_000_000),
... sampler=SliceSampler(cache_values=True, num_slices=10),
... batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
... {
... "episode": episode,
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
... "act": torch.randn((20,)).expand(1000, 20),
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
... }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:
Examples:
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
... batch = batch.reshape(num_slices, -1)
... break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
[14],
[ 8],
[10],
[13],
[ 4],
[ 2],
[ 3],
[22],
[ 8]])
"""
def __init__(
self,
*,
num_slices: int = None,
slice_len: int = None,
end_key: NestedKey | None = None,
traj_key: NestedKey | None = None,
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
) -> object:
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "episode"
if not ((num_slices is None) ^ (slice_len is None)):
raise TypeError(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.num_slices = num_slices
self.slice_len = slice_len
self.end_key = end_key
self.traj_key = traj_key
self.truncated_key = truncated_key
self.cache_values = cache_values
self._fetch_traj = True
self._uses_data_prefix = False
self.strict_length = strict_length
self._cache = {}
@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
# stop_idx = stop_idx.cumsum(0) - 1
# even slower
# t = trajectory.unsqueeze(0)
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
# stop_idx = torch.conv1d(t, w).nonzero()
# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
if end.ndim != 1:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
return start_idx, stop_idx, lengths
def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
else:
# when padding is needed
return torch.cat(
[
_start
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
for _start, _seq_len in zip(start, seq_length)
]
)
def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")
if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage._storage.get(self._used_traj_key)
except KeyError:
trajectory = storage._storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
self._uses_data_prefix = (
isinstance(self._used_traj_key, tuple)
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise
else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage._storage.get(self._used_end_key)
except KeyError:
done = storage._storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
self._uses_data_prefix = (
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise
def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
return seq_length, num_slices
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if not isinstance(storage, TensorStorage):
raise RuntimeError(
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
)
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
def _sample_slices(
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
) -> Tuple[torch.Tensor, dict]:
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths.clamp_max(seq_length)
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length + 1)
)
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
index = self._tensor_slices_from_startend(seq_length, starts)
if self.truncated_key is not None:
truncated_key = self.truncated_key
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
return index.to(torch.long), {truncated_key: truncated}
return index.to(torch.long), {}
@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)
@_used_traj_key.setter
def _used_traj_key(self, value):
self.__dict__["__used_traj_key"] = value
@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)
@_used_end_key.setter
def _used_end_key(self, value):
self.__dict__["__used_end_key"] = value
def _empty(self):
pass
def dumps(self, path):
# no op - cache does not need to be saved
...
def loads(self, path):
# no op
...
def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state