faster replay buffer implementation
This commit is contained in:
@@ -2,30 +2,10 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||||
from torchrl.data.replay_buffers.samplers import RandomSampler
|
|
||||||
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
||||||
|
|
||||||
from common.logger import make_dir
|
from common.logger import make_dir
|
||||||
|
from common.samplers import SliceSampler
|
||||||
|
|
||||||
class DataPrepTransform(Transform):
|
|
||||||
"""
|
|
||||||
Preprocesses data for TD-MPC2 training.
|
|
||||||
Replay data is expected to be a TensorDict with the following keys:
|
|
||||||
obs: observations
|
|
||||||
action: actions
|
|
||||||
reward: rewards
|
|
||||||
task: task IDs (optional)
|
|
||||||
A TensorDict with T time steps has T+1 observations and T actions and rewards.
|
|
||||||
The first actions and rewards in each TensorDict are dummies and should be ignored.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__([])
|
|
||||||
|
|
||||||
def forward(self, td):
|
|
||||||
td = td.permute(1,0)
|
|
||||||
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
|
|
||||||
|
|
||||||
|
|
||||||
class Buffer():
|
class Buffer():
|
||||||
@@ -37,7 +17,9 @@ class Buffer():
|
|||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._device = torch.device('cuda')
|
self._device = torch.device('cuda')
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
|
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
|
||||||
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
|
self._num_steps = 0
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -45,6 +27,11 @@ class Buffer():
|
|||||||
"""Return the capacity of the buffer."""
|
"""Return the capacity of the buffer."""
|
||||||
return self._capacity
|
return self._capacity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_steps(self):
|
||||||
|
"""Return the number of steps in the buffer."""
|
||||||
|
return self._num_steps
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_eps(self):
|
def num_eps(self):
|
||||||
"""Return the number of episodes in the buffer."""
|
"""Return the number of episodes in the buffer."""
|
||||||
@@ -53,32 +40,25 @@ class Buffer():
|
|||||||
def _reserve_buffer(self, storage):
|
def _reserve_buffer(self, storage):
|
||||||
"""
|
"""
|
||||||
Reserve a buffer with the given storage.
|
Reserve a buffer with the given storage.
|
||||||
Uses the RandomSampler to sample trajectories,
|
|
||||||
and the RandomCropTensorDict transform to crop trajectories to the desired length.
|
|
||||||
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
|
|
||||||
"""
|
"""
|
||||||
return ReplayBuffer(
|
return ReplayBuffer(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=RandomSampler(),
|
sampler=SliceSampler(
|
||||||
pin_memory=True,
|
slice_len=self.cfg.horizon+1,
|
||||||
prefetch=1,
|
end_key='done',
|
||||||
transform=Compose(
|
truncated_key=None,
|
||||||
RandomCropTensorDict(self.cfg.horizon+1, -1),
|
|
||||||
DataPrepTransform(),
|
|
||||||
),
|
),
|
||||||
|
pin_memory=True,
|
||||||
|
prefetch=2,
|
||||||
batch_size=self.cfg.batch_size,
|
batch_size=self.cfg.batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init(self, tds):
|
def _init(self, td):
|
||||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||||
mem_free, _ = torch.cuda.mem_get_info()
|
mem_free, _ = torch.cuda.mem_get_info()
|
||||||
bytes_per_ep = sum([
|
bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()])
|
||||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
print(f'Bytes per step: {bytes_per_step:,}')
|
||||||
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
total_bytes = bytes_per_step*self._capacity
|
||||||
for k,v in tds.items()
|
|
||||||
])
|
|
||||||
print(f'Bytes per episode: {bytes_per_ep:,}')
|
|
||||||
total_bytes = bytes_per_ep*self._capacity
|
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
||||||
@@ -92,22 +72,29 @@ class Buffer():
|
|||||||
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
||||||
)
|
)
|
||||||
|
|
||||||
def add(self, tds):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer. All episodes are expected to have the same length."""
|
"""Add a step to the buffer."""
|
||||||
if self._num_eps == 0:
|
done = bool(td['done'].any())
|
||||||
self._buffer = self._init(tds)
|
if done:
|
||||||
self._buffer.add(tds)
|
self._num_eps +=1
|
||||||
self._num_eps += 1
|
td['episode'] = torch.ones_like(td['done']) * self._num_eps
|
||||||
return self._num_eps
|
td['step'] = torch.arange(0, len(td))
|
||||||
|
if self._num_steps == 0:
|
||||||
|
self._buffer = self._init(td)
|
||||||
|
self._buffer.extend(td)
|
||||||
|
self._num_steps += 1
|
||||||
|
return self._num_steps
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of sub-trajectories from the buffer."""
|
"""Sample a batch of sub-trajectories from the buffer."""
|
||||||
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
|
td = self._buffer.sample(batch_size=self._batch_size) \
|
||||||
return obs.to(self._device, non_blocking=True), \
|
.reshape(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
action.to(self._device, non_blocking=True), \
|
obs = td['obs'].to(self._device, non_blocking=True)
|
||||||
reward.to(self._device, non_blocking=True), \
|
action = td['action'][1:].to(self._device, non_blocking=True)
|
||||||
task.to(self._device, non_blocking=True) if task is not None else None
|
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
|
||||||
|
task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None
|
||||||
|
return obs, action, reward, task
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Save the buffer to disk. Useful for storing offline datasets."""
|
"""Save the buffer to disk. Useful for storing offline datasets."""
|
||||||
td = self._buffer._storage._storage.cpu()
|
td = self._buffer._storage._storage.cpu()
|
||||||
|
|||||||
365
tdmpc2/common/samplers.py
Normal file
365
tdmpc2/common/samplers.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import warnings
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from copy import copy, deepcopy
|
||||||
|
from multiprocessing.context import get_spawning_popen
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tensordict import MemoryMappedTensor
|
||||||
|
from tensordict.utils import NestedKey
|
||||||
|
|
||||||
|
from torchrl._extension import EXTENSION_WARNING
|
||||||
|
|
||||||
|
try:
|
||||||
|
from torchrl._torchrl import (
|
||||||
|
MinSegmentTreeFp32,
|
||||||
|
MinSegmentTreeFp64,
|
||||||
|
SumSegmentTreeFp32,
|
||||||
|
SumSegmentTreeFp64,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
warnings.warn(EXTENSION_WARNING)
|
||||||
|
|
||||||
|
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
||||||
|
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
|
||||||
|
from torchrl.data.replay_buffers.samplers import Sampler
|
||||||
|
|
||||||
|
# Error text for attempts to sample from an empty storage.
# NOTE(review): not referenced by the code visible in this file -- confirm it is still needed.
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
|
||||||
|
|
||||||
|
|
||||||
|
class SliceSampler(Sampler):
    """Samples slices of data along the first dimension, given start and stop signals.

    This class samples sub-trajectories with replacement. For a version without
    replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.

    Keyword Args:
        num_slices (int): the number of slices to be sampled. The batch-size
            must be greater than or equal to the ``num_slices`` argument and
            divisible by it. Exclusive with ``slice_len``.
        slice_len (int): the length of the slices to be sampled. The batch-size
            must be greater than or equal to the ``slice_len`` argument and
            divisible by it. Exclusive with ``num_slices``.
        end_key (NestedKey, optional): the key indicating the end of a
            trajectory (or episode). Defaults to ``("next", "done")``.
        traj_key (NestedKey, optional): the key indicating the trajectories.
            Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
        cache_values (bool, optional): to be used with static datasets.
            Will cache the start and end signal of the trajectory. When
            ``False`` (default) the trajectory boundaries are recomputed from
            the storage on every call to :meth:`sample`.
        truncated_key (NestedKey, optional): If not ``None``, this argument
            indicates where a truncated signal should be written in the output
            data. This is used to indicate to value estimators where the provided
            trajectory breaks. Defaults to ``("next", "truncated")``.
            This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
            instances (otherwise the truncated key is returned in the info dictionary
            returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
        strict_length (bool, optional): if ``False``, trajectories of length
            shorter than `slice_len` (or `batch_size // num_slices`) will be
            allowed to appear in the batch.
            Be mindful that this can result in effective `batch_size` shorter
            than the one asked for! Trajectories can be split using
            :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.

    .. note:: To recover the trajectory splits in the storage,
      :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
      attempt to find the ``traj_key`` entry in the storage. If it cannot be
      found, the ``end_key`` will be used to reconstruct the episodes.

    Examples:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
        >>> from torchrl.data.replay_buffers.samplers import SliceSampler
        >>> torch.manual_seed(0)
        >>> rb = TensorDictReplayBuffer(
        ...     storage=LazyMemmapStorage(1_000_000),
        ...     sampler=SliceSampler(cache_values=True, num_slices=10),
        ...     batch_size=320,
        ... )
        >>> episode = torch.zeros(1000, dtype=torch.int)
        >>> episode[:300] = 1
        >>> episode[300:550] = 2
        >>> episode[550:700] = 3
        >>> episode[700:] = 4
        >>> data = TensorDict(
        ...     {
        ...         "episode": episode,
        ...         "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
        ...         "act": torch.randn((20,)).expand(1000, 20),
        ...         "other": torch.randn((20, 50)).expand(1000, 20, 50),
        ...     }, [1000]
        ... )
        >>> rb.extend(data)
        >>> sample = rb.sample()
        >>> print("sample:", sample)
        >>> print("episodes", sample.get("episode").unique())
        episodes tensor([1, 2, 3, 4], dtype=torch.int32)

    :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
    most of TorchRL's datasets:

    Examples:
        >>> import torch
        >>>
        >>> from torchrl.data.datasets import RobosetExperienceReplay
        >>> from torchrl.data import SliceSampler
        >>>
        >>> torch.manual_seed(0)
        >>> num_slices = 10
        >>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
        >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
        >>> for batch in data:
        ...     batch = batch.reshape(num_slices, -1)
        ...     break
        >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
        check that each batch only has one episode: tensor([[19],
                [14],
                [ 8],
                [10],
                [13],
                [ 4],
                [ 2],
                [ 3],
                [22],
                [ 8]])

    """

    def __init__(
        self,
        *,
        num_slices: int | None = None,
        slice_len: int | None = None,
        end_key: NestedKey | None = None,
        traj_key: NestedKey | None = None,
        cache_values: bool = False,
        truncated_key: NestedKey | None = ("next", "truncated"),
        strict_length: bool = True,
    ) -> None:
        """Configure the sampler.

        Raises:
            TypeError: if both or neither of ``num_slices`` and ``slice_len``
                are provided.
        """
        if end_key is None:
            end_key = ("next", "done")
        if traj_key is None:
            traj_key = "episode"
        # Exactly one of the two sizing arguments must be given.
        if not ((num_slices is None) ^ (slice_len is None)):
            raise TypeError(
                "Either num_slices or slice_len must be not None, and not both. "
                f"Got num_slices={num_slices} and slice_len={slice_len}."
            )
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.end_key = end_key
        self.traj_key = traj_key
        self.truncated_key = truncated_key
        self.cache_values = cache_values
        # Whether to look for ``traj_key`` (True) or ``end_key`` (False) first.
        self._fetch_traj = True
        self._uses_data_prefix = False
        self.strict_length = strict_length
        self._cache = {}

    @staticmethod
    def _find_start_stop_traj(*, trajectory=None, end=None):
        """Compute trajectory boundaries from a trajectory-id or an end signal.

        Args:
            trajectory: tensor of per-step trajectory ids; a boundary is placed
                wherever two consecutive ids differ.
            end: boolean end-of-trajectory signal, used only when
                ``trajectory`` is ``None``.

        Returns:
            A ``(start_idx, stop_idx, lengths)`` tuple of 1-D index tensors,
            with ``stop_idx`` inclusive.

        Raises:
            RuntimeError: if the resulting end signal is not 1-dimensional.
        """
        if trajectory is not None:
            # Comparing shifted views is faster than ``unique_consecutive``
            # or a conv1d-based difference.
            changes = trajectory[:-1] != trajectory[1:]
            # The last stored step always closes a trajectory. Building the
            # flag from ``trajectory[-1:]`` (rather than from ``changes[:1]``)
            # keeps a single stored step from yielding zero trajectories.
            last = torch.ones_like(trajectory[-1:], dtype=torch.bool)
            end = torch.cat([changes, last], 0)
        else:
            # Force the last step to be an end so the final (possibly
            # unfinished) trajectory is closed.
            end = torch.index_fill(
                end,
                index=torch.tensor(-1, device=end.device, dtype=torch.long),
                dim=0,
                value=1,
            )
        if end.ndim != 1:
            raise RuntimeError(
                f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
            )
        stop_idx = end.view(-1).nonzero().view(-1)
        start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
        lengths = stop_idx - start_idx + 1
        return start_idx, stop_idx, lengths

    def _tensor_slices_from_startend(self, seq_length, start):
        """Expand slice starts into a flat tensor of consecutive indices.

        Args:
            seq_length: an ``int`` (same length for every slice) or a
                per-slice sequence of lengths.
            start: 1-D tensor with the first index of every slice.
        """
        if isinstance(seq_length, int):
            return (
                torch.arange(
                    seq_length, device=start.device, dtype=start.dtype
                ).unsqueeze(0)
                + start.unsqueeze(1)
            ).view(-1)
        else:
            # Slices have different lengths: build them one by one.
            return torch.cat(
                [
                    _start
                    + torch.arange(_seq_len, device=start.device, dtype=start.dtype)
                    for _start, _seq_len in zip(start, seq_length)
                ]
            )

    def _get_stop_and_length(self, storage, fallback=True):
        """Return ``(start_idx, stop_idx, lengths)`` for the stored trajectories.

        The trajectory key is tried first and the end key second (or the
        reverse, depending on which one worked last time). Results are only
        memoized when ``cache_values`` is ``True``; otherwise they are
        recomputed so that newly written data is taken into account.

        Raises:
            KeyError: if neither key can be found in the storage.
        """
        if self.cache_values and "stop-and-length" in self._cache:
            return self._cache.get("stop-and-length")

        if self._fetch_traj:
            # We first try with the traj_key
            try:
                # In some cases, the storage hides the data behind "_data".
                # In the future, this may be deprecated, and we don't want to mess
                # with the keys provided by the user so we fall back on a proxy to
                # the traj key.
                try:
                    trajectory = storage._storage.get(self._used_traj_key)
                except KeyError:
                    trajectory = storage._storage.get(("_data", self.traj_key))
                    # cache that value for future use
                    self._used_traj_key = ("_data", self.traj_key)
                self._uses_data_prefix = (
                    isinstance(self._used_traj_key, tuple)
                    and self._used_traj_key[0] == "_data"
                )
                vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
                # Do not use ``setdefault`` here: it would keep returning the
                # first result forever, even with ``cache_values=False``.
                if self.cache_values:
                    self._cache["stop-and-length"] = vals
                return vals
            except KeyError:
                if fallback:
                    self._fetch_traj = False
                    return self._get_stop_and_length(storage, fallback=False)
                raise

        else:
            try:
                # Same "_data" fallback as above, for the end key.
                try:
                    done = storage._storage.get(self._used_end_key)
                except KeyError:
                    done = storage._storage.get(("_data", self.end_key))
                    # cache that value for future use
                    self._used_end_key = ("_data", self.end_key)
                self._uses_data_prefix = (
                    isinstance(self._used_end_key, tuple)
                    and self._used_end_key[0] == "_data"
                )
                # Restrict the signal to the filled part of the storage
                # *before* looking for boundaries (slicing the returned tuple
                # instead would leave unfilled slots in the index ranges).
                vals = self._find_start_stop_traj(
                    end=done.squeeze()[: len(storage)]
                )
                if self.cache_values:
                    self._cache["stop-and-length"] = vals
                return vals
            except KeyError:
                if fallback:
                    self._fetch_traj = True
                    return self._get_stop_and_length(storage, fallback=False)
                raise

    def _adjusted_batch_size(self, batch_size):
        """Split ``batch_size`` into ``(seq_length, num_slices)``.

        Raises:
            RuntimeError: if ``batch_size`` is not divisible by the configured
                ``num_slices`` / ``slice_len``.
        """
        if self.num_slices is not None:
            if batch_size % self.num_slices != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
                )
            seq_length = batch_size // self.num_slices
            num_slices = self.num_slices
        else:
            if batch_size % self.slice_len != 0:
                raise RuntimeError(
                    f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
                )
            seq_length = self.slice_len
            num_slices = batch_size // self.slice_len
        return seq_length, num_slices

    def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
        """Sample flat storage indices forming ``batch_size`` steps of slices.

        Returns:
            A ``(index, info)`` tuple; ``info`` holds the truncated signal
            under ``truncated_key`` when that key is not ``None``.

        Raises:
            RuntimeError: if ``storage`` is not a ``TensorStorage``.
        """
        if not isinstance(storage, TensorStorage):
            raise RuntimeError(
                f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
            )

        # pick up as many trajs as we need
        start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
        seq_length, num_slices = self._adjusted_batch_size(batch_size)
        return self._sample_slices(lengths, start_idx, seq_length, num_slices)

    def _sample_slices(
        self, lengths, start_idx, seq_length, num_slices, traj_idx=None
    ) -> Tuple[torch.Tensor, dict]:
        """Draw slices of ``seq_length`` steps from randomly chosen trajectories.

        Args:
            lengths: per-trajectory lengths.
            start_idx: per-trajectory first index in the storage.
            seq_length: requested slice length.
            num_slices: number of slices to draw (ignored if ``traj_idx`` is given).
            traj_idx: optional pre-selected trajectory indices.

        Raises:
            RuntimeError: if a stored trajectory is shorter than ``seq_length``
                and ``strict_length`` is ``True``.
        """
        has_short_traj = bool((lengths < seq_length).any())
        if has_short_traj and self.strict_length:
            raise RuntimeError(
                "Some stored trajectories have a length shorter than the slice that was asked for. "
                "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
                "in you batch."
            )

        if traj_idx is None:
            traj_idx = torch.randint(
                lengths.shape[0], (num_slices,), device=lengths.device
            )
        else:
            num_slices = traj_idx.shape[0]

        if has_short_traj:
            # Clamp per *sampled slice* (not per stored trajectory) so the
            # result lines up with ``traj_idx`` / ``starts`` below.
            seq_length = lengths[traj_idx].clamp_max(seq_length)

        # Largest valid offset of a slice inside its trajectory (inclusive).
        max_start = (lengths[traj_idx] - seq_length).to(start_idx.dtype)
        relative_starts = (
            (torch.rand(num_slices, device=lengths.device) * (max_start + 1))
            .floor()
            .to(start_idx.dtype)
        )
        # Guard against floating-point rounding pushing an offset past the end.
        relative_starts = torch.minimum(relative_starts, max_start)
        starts = start_idx[traj_idx] + relative_starts
        index = self._tensor_slices_from_startend(seq_length, starts)
        if self.truncated_key is not None:
            truncated_key = self.truncated_key

            # Flag the last step of every slice as truncated.
            truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
            if isinstance(seq_length, int):
                truncated.view(num_slices, -1)[:, -1] = 1
            else:
                truncated[seq_length.cumsum(0) - 1] = 1
            return index.to(torch.long), {truncated_key: truncated}
        return index.to(torch.long), {}

    @property
    def _used_traj_key(self):
        # Falls back to the user-provided key until an override is stored.
        return self.__dict__.get("__used_traj_key", self.traj_key)

    @_used_traj_key.setter
    def _used_traj_key(self, value):
        self.__dict__["__used_traj_key"] = value

    @property
    def _used_end_key(self):
        # Falls back to the user-provided key until an override is stored.
        return self.__dict__.get("__used_end_key", self.end_key)

    @_used_end_key.setter
    def _used_end_key(self, value):
        self.__dict__["__used_end_key"] = value

    def _empty(self):
        pass

    def dumps(self, path):
        # no op - cache does not need to be saved
        ...

    def loads(self, path):
        # no op
        ...

    def __getstate__(self):
        # Copy the instance state with the cache reset to an empty dict.
        state = copy(self.__dict__)
        state["_cache"] = {}
        return state
|
||||||
@@ -60,11 +60,11 @@ dropout: 0.01
|
|||||||
simnorm_dim: 8
|
simnorm_dim: 8
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
wandb_project: ???
|
wandb_project: tdmpcv2
|
||||||
wandb_entity: ???
|
wandb_entity: nicklashansen
|
||||||
wandb_silent: false
|
wandb_silent: false
|
||||||
disable_wandb: true
|
disable_wandb: false
|
||||||
save_csv: true
|
save_csv: false
|
||||||
|
|
||||||
# misc
|
# misc
|
||||||
save_video: true
|
save_video: true
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import gym
|
|||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
from envs.dmcontrol import make_env as make_dm_control_env
|
from envs.dmcontrol import make_env as make_dm_control_env
|
||||||
from envs.maniskill import make_env as make_maniskill_env
|
# from envs.maniskill import make_env as make_maniskill_env
|
||||||
from envs.metaworld import make_env as make_metaworld_env
|
# from envs.metaworld import make_env as make_metaworld_env
|
||||||
from envs.myosuite import make_env as make_myosuite_env
|
# from envs.myosuite import make_env as make_myosuite_env
|
||||||
from envs.exceptions import UnknownTaskError
|
from envs.exceptions import UnknownTaskError
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
@@ -44,7 +44,7 @@ def make_env(cfg):
|
|||||||
env = make_multitask_env(cfg)
|
env = make_multitask_env(cfg)
|
||||||
else:
|
else:
|
||||||
env = None
|
env = None
|
||||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||||
try:
|
try:
|
||||||
env = fn(cfg)
|
env = fn(cfg)
|
||||||
except UnknownTaskError:
|
except UnknownTaskError:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class OnlineTrainer(Trainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._step = 0
|
self._step = 0
|
||||||
self._ep_idx = 0
|
self._ep_idx = 0
|
||||||
|
self._ep_reward = 0
|
||||||
self._start_time = time()
|
self._start_time = time()
|
||||||
|
|
||||||
def common_metrics(self):
|
def common_metrics(self):
|
||||||
@@ -21,6 +22,7 @@ class OnlineTrainer(Trainer):
|
|||||||
return dict(
|
return dict(
|
||||||
step=self._step,
|
step=self._step,
|
||||||
episode=self._ep_idx,
|
episode=self._ep_idx,
|
||||||
|
episode_reward=self._ep_reward,
|
||||||
total_time=time() - self._start_time,
|
total_time=time() - self._start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,22 +49,26 @@ class OnlineTrainer(Trainer):
|
|||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None):
|
def to_td(self, obs, action=None, reward=None, done=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
|
obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu()
|
||||||
else:
|
else:
|
||||||
obs = obs.unsqueeze(0).cpu()
|
obs = obs.cpu()
|
||||||
if action is None:
|
if action is None:
|
||||||
action = torch.empty_like(self.env.rand_act())
|
action = torch.empty_like(self.env.rand_act())
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
|
if done is None:
|
||||||
|
done = False
|
||||||
|
done = torch.tensor(done)
|
||||||
td = TensorDict(dict(
|
td = TensorDict(dict(
|
||||||
obs=obs,
|
obs=obs.unsqueeze(0),
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
|
done=done.unsqueeze(0),
|
||||||
), batch_size=(1,))
|
), batch_size=(1,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Train a TD-MPC2 agent."""
|
||||||
@@ -83,15 +89,16 @@ class OnlineTrainer(Trainer):
|
|||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx += 1
|
||||||
|
self.buffer.add(torch.cat(self._tds))
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
self._tds = [self.to_td(obs)]
|
self._tds = [self.to_td(obs)]
|
||||||
|
self._ep_reward = 0
|
||||||
|
|
||||||
# Collect experience
|
# Collect experience
|
||||||
if self._step > self.cfg.seed_steps:
|
if self._step > self.cfg.seed_steps:
|
||||||
@@ -99,7 +106,8 @@ class OnlineTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
action = self.env.rand_act()
|
action = self.env.rand_act()
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
self._tds.append(self.to_td(obs, action, reward, done))
|
||||||
|
self._ep_reward += reward
|
||||||
|
|
||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user