integrate slicesampler as default
This commit is contained in:
@@ -1,27 +1,28 @@
|
|||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||||
from torchrl.data.replay_buffers.samplers import RandomSampler
|
|
||||||
from torchrl.envs import RandomCropTensorDict
|
|
||||||
|
|
||||||
from common.samplers import SliceSampler
|
from common.samplers import SliceSampler
|
||||||
|
|
||||||
|
|
||||||
class Buffer():
|
class Buffer():
|
||||||
"""
|
"""
|
||||||
Base class for TD-MPC2 replay buffers.
|
Replay buffer for TD-MPC2 training. Based on torchrl.
|
||||||
Uses CUDA memory if available, and CPU memory otherwise.
|
Uses CUDA memory if available, and CPU memory otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._device = torch.device('cuda')
|
self._device = torch.device('cuda')
|
||||||
self._capacity = None
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
self._max_eps = None
|
self._sampler = SliceSampler(
|
||||||
|
num_slices=self.cfg.batch_size,
|
||||||
|
end_key=None,
|
||||||
|
traj_key='episode',
|
||||||
|
truncated_key=None,
|
||||||
|
)
|
||||||
|
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
self._sampler = None
|
|
||||||
self._transform = None
|
|
||||||
self._batch_size = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def capacity(self):
|
def capacity(self):
|
||||||
@@ -42,21 +43,19 @@ class Buffer():
|
|||||||
sampler=self._sampler,
|
sampler=self._sampler,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
prefetch=1,
|
prefetch=1,
|
||||||
transform=self._transform,
|
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init(self, tds):
|
def _init(self, tds):
|
||||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||||
print('Buffer capacity:', self._capacity)
|
print(f'Buffer capacity: {self._capacity:,}')
|
||||||
mem_free, _ = torch.cuda.mem_get_info()
|
mem_free, _ = torch.cuda.mem_get_info()
|
||||||
bytes_per_ep = sum([
|
bytes_per_step = sum([
|
||||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||||
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
||||||
for k,v in tds.items()
|
for v in tds.values()
|
||||||
])
|
]) / len(tds)
|
||||||
print(f'Bytes per episode: {bytes_per_ep:,}')
|
total_bytes = bytes_per_step*self._capacity
|
||||||
total_bytes = bytes_per_ep*self._max_eps
|
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
||||||
@@ -82,68 +81,15 @@ class Buffer():
|
|||||||
task = td['task'][0] if 'task' in td.keys() else None
|
task = td['task'][0] if 'task' in td.keys() else None
|
||||||
return self._to_device(obs, action, reward, task)
|
return self._to_device(obs, action, reward, task)
|
||||||
|
|
||||||
def _add(self, td):
|
|
||||||
"""Internal function that adds episode to the buffer."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def add(self, td):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer."""
|
"""Add an episode to the buffer."""
|
||||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
||||||
if self._num_eps == 0:
|
if self._num_eps == 0:
|
||||||
self._buffer = self._init(td)
|
self._buffer = self._init(td)
|
||||||
self._add(td)
|
self._buffer.extend(td)
|
||||||
self._num_eps += 1
|
self._num_eps += 1
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
def sample(self):
|
|
||||||
"""Sample a batch of sub-trajectories from the buffer."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CropBuffer(Buffer):
|
|
||||||
"""
|
|
||||||
A replay buffer that first samples trajectories, and then crops to desired length.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
super().__init__(cfg)
|
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
|
|
||||||
self._max_eps = self._capacity
|
|
||||||
self._sampler = RandomSampler()
|
|
||||||
self._transform = RandomCropTensorDict(cfg.horizon+1, -1)
|
|
||||||
self._batch_size = cfg.batch_size
|
|
||||||
|
|
||||||
def _add(self, td):
|
|
||||||
"""Add an episode to the buffer, with trajectories as the leading dimension."""
|
|
||||||
self._buffer.add(td)
|
|
||||||
|
|
||||||
def sample(self):
|
|
||||||
"""Sample a batch of subsequences from the buffer."""
|
|
||||||
td = self._buffer.sample().permute(1,0)
|
|
||||||
return self._prepare_batch(td)
|
|
||||||
|
|
||||||
|
|
||||||
class SliceBuffer(Buffer):
|
|
||||||
"""
|
|
||||||
A replay buffer that directly samples subsequences. More efficient than CropBuffer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cfg):
|
|
||||||
super().__init__(cfg)
|
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
|
||||||
self._max_eps = self._capacity//cfg.episode_length
|
|
||||||
self._sampler = SliceSampler(
|
|
||||||
num_slices=self.cfg.batch_size,
|
|
||||||
end_key=None,
|
|
||||||
traj_key='episode',
|
|
||||||
truncated_key=None,
|
|
||||||
)
|
|
||||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
|
||||||
|
|
||||||
def _add(self, td):
|
|
||||||
"""Add an episode to the buffer, with transitions as the leading dimension."""
|
|
||||||
self._buffer.extend(td)
|
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of subsequences from the buffer."""
|
"""Sample a batch of subsequences from the buffer."""
|
||||||
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
|
|||||||
@@ -1,36 +1,18 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
from copy import copy
|
||||||
import warnings
|
from typing import Tuple
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from copy import copy, deepcopy
|
|
||||||
from multiprocessing.context import get_spawning_popen
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Tuple, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tensordict import MemoryMappedTensor
|
|
||||||
from tensordict.utils import NestedKey
|
from tensordict.utils import NestedKey
|
||||||
|
|
||||||
from torchrl._extension import EXTENSION_WARNING
|
|
||||||
|
|
||||||
try:
|
|
||||||
from torchrl._torchrl import (
|
|
||||||
MinSegmentTreeFp32,
|
|
||||||
MinSegmentTreeFp64,
|
|
||||||
SumSegmentTreeFp32,
|
|
||||||
SumSegmentTreeFp64,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
warnings.warn(EXTENSION_WARNING)
|
|
||||||
|
|
||||||
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
||||||
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
|
|
||||||
from torchrl.data.replay_buffers.samplers import Sampler
|
from torchrl.data.replay_buffers.samplers import Sampler
|
||||||
|
|
||||||
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
|
|
||||||
|
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
|
||||||
|
# This copy will live here until it has been included in a few torchrl stable releases
|
||||||
|
|
||||||
|
|
||||||
class SliceSampler(Sampler):
|
class SliceSampler(Sampler):
|
||||||
|
|||||||
@@ -1,64 +0,0 @@
|
|||||||
import os
|
|
||||||
os.environ['LAZY_LEGACY_OP'] = '0'
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import hydra
|
|
||||||
from tensordict.tensordict import TensorDict
|
|
||||||
|
|
||||||
from common.buffer import CropBuffer, SliceBuffer
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_name='config', config_path='.')
|
|
||||||
def test_buffer(cfg: dict):
|
|
||||||
cfg.episode_length = 11
|
|
||||||
cfg.batch_size = 8
|
|
||||||
|
|
||||||
transitions0 = [TensorDict(dict(
|
|
||||||
obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
|
|
||||||
action=torch.tensor([-1.]).unsqueeze(0) ** t,
|
|
||||||
reward=torch.tensor([1.]) * t,
|
|
||||||
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
|
||||||
episode0 = torch.cat(transitions0)
|
|
||||||
|
|
||||||
transitions1 = [TensorDict(dict(
|
|
||||||
obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
|
|
||||||
action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5,
|
|
||||||
reward=torch.tensor([-1.]) * t,
|
|
||||||
), batch_size=(1,)) for t in range(cfg.episode_length)]
|
|
||||||
episode1 = torch.cat(transitions1)
|
|
||||||
|
|
||||||
crop_buffer = CropBuffer(cfg)
|
|
||||||
slice_buffer = SliceBuffer(cfg)
|
|
||||||
|
|
||||||
crop_buffer.add(episode0)
|
|
||||||
slice_buffer.add(episode0)
|
|
||||||
|
|
||||||
crop_buffer.add(episode1)
|
|
||||||
slice_buffer.add(episode1)
|
|
||||||
|
|
||||||
crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
|
|
||||||
slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()
|
|
||||||
|
|
||||||
assert crop_obs.shape == slice_obs.shape
|
|
||||||
assert crop_action.shape == slice_action.shape
|
|
||||||
assert crop_reward.shape == slice_reward.shape
|
|
||||||
|
|
||||||
assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
|
|
||||||
assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()
|
|
||||||
assert (crop_action.mean().abs() < 0.2)
|
|
||||||
assert (slice_action.mean().abs() < 0.2)
|
|
||||||
|
|
||||||
crop_rewards, slice_rewards = [], []
|
|
||||||
for _ in range(100_000):
|
|
||||||
_, _, crop_reward_, _ = crop_buffer.sample()
|
|
||||||
_, _, slice_reward_, _ = slice_buffer.sample()
|
|
||||||
crop_rewards.append(crop_reward_.mean())
|
|
||||||
slice_rewards.append(slice_reward_.mean())
|
|
||||||
|
|
||||||
crop_rewards = torch.tensor(crop_rewards).mean()
|
|
||||||
slice_rewards = torch.tensor(slice_rewards).mean()
|
|
||||||
assert (crop_rewards - slice_rewards) < 0.1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test_buffer()
|
|
||||||
@@ -10,7 +10,7 @@ from termcolor import colored
|
|||||||
|
|
||||||
from common.parser import parse_cfg
|
from common.parser import parse_cfg
|
||||||
from common.seed import set_seed
|
from common.seed import set_seed
|
||||||
from common.buffer import CropBuffer, SliceBuffer
|
from common.buffer import Buffer
|
||||||
from envs import make_env
|
from envs import make_env
|
||||||
from tdmpc2 import TDMPC2
|
from tdmpc2 import TDMPC2
|
||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
@@ -51,7 +51,7 @@ def train(cfg: dict):
|
|||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
agent=TDMPC2(cfg),
|
agent=TDMPC2(cfg),
|
||||||
buffer=SliceBuffer(cfg),
|
buffer=Buffer(cfg),
|
||||||
logger=Logger(cfg),
|
logger=Logger(cfg),
|
||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|||||||
Reference in New Issue
Block a user