integrate slicesampler as default
This commit is contained in:
@@ -1,27 +1,28 @@
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||
from torchrl.data.replay_buffers.samplers import RandomSampler
|
||||
from torchrl.envs import RandomCropTensorDict
|
||||
|
||||
from common.samplers import SliceSampler
|
||||
|
||||
|
||||
class Buffer():
|
||||
"""
|
||||
Base class for TD-MPC2 replay buffers.
|
||||
Replay buffer for TD-MPC2 training. Based on torchrl.
|
||||
Uses CUDA memory if available, and CPU memory otherwise.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._device = torch.device('cuda')
|
||||
self._capacity = None
|
||||
self._max_eps = None
|
||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||
self._sampler = SliceSampler(
|
||||
num_slices=self.cfg.batch_size,
|
||||
end_key=None,
|
||||
traj_key='episode',
|
||||
truncated_key=None,
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
self._num_eps = 0
|
||||
self._sampler = None
|
||||
self._transform = None
|
||||
self._batch_size = None
|
||||
|
||||
@property
|
||||
def capacity(self):
|
||||
@@ -42,21 +43,19 @@ class Buffer():
|
||||
sampler=self._sampler,
|
||||
pin_memory=True,
|
||||
prefetch=1,
|
||||
transform=self._transform,
|
||||
batch_size=self._batch_size,
|
||||
)
|
||||
|
||||
def _init(self, tds):
|
||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||
print('Buffer capacity:', self._capacity)
|
||||
print(f'Buffer capacity: {self._capacity:,}')
|
||||
mem_free, _ = torch.cuda.mem_get_info()
|
||||
bytes_per_ep = sum([
|
||||
bytes_per_step = sum([
|
||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
||||
for k,v in tds.items()
|
||||
])
|
||||
print(f'Bytes per episode: {bytes_per_ep:,}')
|
||||
total_bytes = bytes_per_ep*self._max_eps
|
||||
for v in tds.values()
|
||||
]) / len(tds)
|
||||
total_bytes = bytes_per_step*self._capacity
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
# Heuristic: decide whether to use CUDA or CPU memory
|
||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
||||
@@ -82,68 +81,15 @@ class Buffer():
|
||||
task = td['task'][0] if 'task' in td.keys() else None
|
||||
return self._to_device(obs, action, reward, task)
|
||||
|
||||
def _add(self, td):
|
||||
"""Internal function that adds episode to the buffer."""
|
||||
pass
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
||||
if self._num_eps == 0:
|
||||
self._buffer = self._init(td)
|
||||
self._add(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of sub-trajectories from the buffer."""
|
||||
pass
|
||||
|
||||
|
||||
class CropBuffer(Buffer):
    """
    Replay buffer that samples whole trajectories first and crops them afterwards.

    Capacity is counted in episodes: the step budget
    ``min(cfg.buffer_size, cfg.steps)`` is divided by ``cfg.episode_length``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        step_budget = min(cfg.buffer_size, cfg.steps)
        episode_budget = step_budget // cfg.episode_length
        # Capacity and the episode limit are the same number for this buffer.
        self._capacity = episode_budget
        self._max_eps = self._capacity
        self._sampler = RandomSampler()
        # NOTE(review): presumably crops horizon+1 steps along the last
        # dimension -- confirm against the torchrl RandomCropTensorDict docs.
        crop_length = cfg.horizon + 1
        self._transform = RandomCropTensorDict(crop_length, -1)
        self._batch_size = cfg.batch_size

    def _add(self, td):
        """Store one episode as a single buffer entry (trajectory-major layout)."""
        self._buffer.add(td)

    def sample(self):
        """Draw a batch from the buffer, swap its first two dims and prepare it."""
        batch = self._buffer.sample()
        swapped = batch.permute(1, 0)
        return self._prepare_batch(swapped)
|
||||
|
||||
|
||||
class SliceBuffer(Buffer):
|
||||
"""
|
||||
A replay buffer that directly samples subsequences. More efficient than CropBuffer.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
super().__init__(cfg)
|
||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||
self._max_eps = self._capacity//cfg.episode_length
|
||||
self._sampler = SliceSampler(
|
||||
num_slices=self.cfg.batch_size,
|
||||
end_key=None,
|
||||
traj_key='episode',
|
||||
truncated_key=None,
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
|
||||
def _add(self, td):
|
||||
"""Add an episode to the buffer, with transitions as the leading dimension."""
|
||||
self._buffer.extend(td)
|
||||
|
||||
def sample(self):
|
||||
"""Sample a batch of subsequences from the buffer."""
|
||||
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||
|
||||
@@ -1,36 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import copy, deepcopy
|
||||
from multiprocessing.context import get_spawning_popen
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
from copy import copy
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tensordict import MemoryMappedTensor
|
||||
from tensordict.utils import NestedKey
|
||||
|
||||
from torchrl._extension import EXTENSION_WARNING
|
||||
|
||||
try:
|
||||
from torchrl._torchrl import (
|
||||
MinSegmentTreeFp32,
|
||||
MinSegmentTreeFp64,
|
||||
SumSegmentTreeFp32,
|
||||
SumSegmentTreeFp64,
|
||||
)
|
||||
except ImportError:
|
||||
warnings.warn(EXTENSION_WARNING)
|
||||
|
||||
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
||||
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
|
||||
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
|
||||
|
||||
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
|
||||
# This copy will live here until it has been included in a few torchrl stable releases
|
||||
|
||||
|
||||
class SliceSampler(Sampler):
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
import os
|
||||
os.environ['LAZY_LEGACY_OP'] = '0'
|
||||
|
||||
import torch
|
||||
import hydra
|
||||
from tensordict.tensordict import TensorDict
|
||||
|
||||
from common.buffer import CropBuffer, SliceBuffer
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
def test_buffer(cfg: dict):
    """Check that CropBuffer and SliceBuffer sample equivalent batches.

    Two synthetic episodes of ``cfg.episode_length`` steps are added to both
    buffers. The test then verifies that a sampled batch has the same shapes
    in both buffers, that consecutive observations differ by exactly 1, that
    the mean action is close to zero, and that the long-run mean reward of
    the two buffers agrees.

    Args:
        cfg: Hydra config; ``episode_length`` and ``batch_size`` are
            overridden here, the remaining fields are consumed by the buffers.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    cfg.episode_length = 11
    cfg.batch_size = 8

    # Episode 0: obs counts up from [0..4], action alternates +1/-1,
    # reward grows as +t.
    transitions0 = [TensorDict(dict(
        obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
        action=torch.tensor([-1.]).unsqueeze(0) ** t,
        reward=torch.tensor([1.]) * t,
    ), batch_size=(1,)) for t in range(cfg.episode_length)]
    episode0 = torch.cat(transitions0)

    # Episode 1: obs counts up from [20..24], action alternates +0.5/-0.5,
    # reward decreases as -t.
    transitions1 = [TensorDict(dict(
        obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
        action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5,
        reward=torch.tensor([-1.]) * t,
    ), batch_size=(1,)) for t in range(cfg.episode_length)]
    episode1 = torch.cat(transitions1)

    crop_buffer = CropBuffer(cfg)
    slice_buffer = SliceBuffer(cfg)

    for episode in (episode0, episode1):
        crop_buffer.add(episode)
        slice_buffer.add(episode)

    crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
    slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()

    # Both buffers must return identically shaped batches.
    assert crop_obs.shape == slice_obs.shape
    assert crop_action.shape == slice_action.shape
    assert crop_reward.shape == slice_reward.shape

    # Observations advance by 1 per step in both episodes, so adjacent
    # entries along the leading axis must differ by exactly 1
    # (assumes the leading axis of the sampled obs is time -- TODO confirm).
    assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
    assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()

    # Actions alternate in sign, so their mean should be near zero.
    assert crop_action.mean().abs() < 0.2
    assert slice_action.mean().abs() < 0.2

    # Collect plain Python floats so the final tensors are built from scalars.
    crop_rewards, slice_rewards = [], []
    for _ in range(100_000):
        _, _, crop_reward_, _ = crop_buffer.sample()
        _, _, slice_reward_, _ = slice_buffer.sample()
        crop_rewards.append(crop_reward_.mean().item())
        slice_rewards.append(slice_reward_.mean().item())

    crop_rewards = torch.tensor(crop_rewards).mean()
    slice_rewards = torch.tensor(slice_rewards).mean()
    # Compare the absolute gap: a signed difference would pass trivially
    # whenever the crop mean is far below the slice mean.
    assert (crop_rewards - slice_rewards).abs() < 0.1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_buffer()
|
||||
@@ -10,7 +10,7 @@ from termcolor import colored
|
||||
|
||||
from common.parser import parse_cfg
|
||||
from common.seed import set_seed
|
||||
from common.buffer import CropBuffer, SliceBuffer
|
||||
from common.buffer import Buffer
|
||||
from envs import make_env
|
||||
from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
@@ -51,7 +51,7 @@ def train(cfg: dict):
|
||||
cfg=cfg,
|
||||
env=make_env(cfg),
|
||||
agent=TDMPC2(cfg),
|
||||
buffer=SliceBuffer(cfg),
|
||||
buffer=Buffer(cfg),
|
||||
logger=Logger(cfg),
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
Reference in New Issue
Block a user