integrate slicesampler as default

This commit is contained in:
Nicklas Hansen
2023-12-27 08:49:04 -08:00
parent 2f86a1e4d8
commit 54145a4d8c
4 changed files with 22 additions and 158 deletions

View File

@@ -1,27 +1,28 @@
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict
from common.samplers import SliceSampler
class Buffer():
"""
Base class for TD-MPC2 replay buffers.
Replay buffer for TD-MPC2 training. Based on torchrl.
Uses CUDA memory if available, and CPU memory otherwise.
"""
def __init__(self, cfg):
    """
    Store the config and set up capacity, sampler and batch size.

    Reads `cfg.buffer_size`, `cfg.steps`, `cfg.batch_size` and `cfg.horizon`.
    Every attribute is assigned exactly once so that no later `= None`
    reset can overwrite the configured sampler or batch size.
    """
    self.cfg = cfg
    self._device = torch.device('cuda')
    # Capacity is bounded by both the configured buffer size and cfg.steps.
    self._capacity = min(cfg.buffer_size, cfg.steps)
    # Retained as placeholders: other code in this file still reads them.
    self._max_eps = None
    self._transform = None
    # One slice per batch element; slices are grouped by the 'episode' key.
    self._sampler = SliceSampler(
        num_slices=cfg.batch_size,
        end_key=None,
        traj_key='episode',
        truncated_key=None,
    )
    # The sampler returns flat transitions: batch_size slices of horizon+1 steps.
    self._batch_size = cfg.batch_size * (cfg.horizon+1)
    self._num_eps = 0
@property
def capacity(self):
@@ -42,21 +43,19 @@ class Buffer():
sampler=self._sampler,
pin_memory=True,
prefetch=1,
transform=self._transform,
batch_size=self._batch_size,
)
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
print('Buffer capacity:', self._capacity)
print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([
bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._max_eps
for v in tds.values()
]) / len(tds)
total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
@@ -82,68 +81,15 @@ class Buffer():
task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task)
def _add(self, td):
"""Internal function that adds episode to the buffer."""
pass
def add(self, td):
"""Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0:
self._buffer = self._init(td)
self._add(td)
self._buffer.extend(td)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
pass
class CropBuffer(Buffer):
    """
    Replay buffer that samples stored trajectories and then crops each one
    to the desired sub-sequence length.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Capacity is expressed in episodes rather than individual steps.
        step_budget = min(cfg.buffer_size, cfg.steps)
        self._capacity = step_budget//cfg.episode_length
        self._max_eps = self._capacity
        self._sampler = RandomSampler()
        crop_length = cfg.horizon+1
        self._transform = RandomCropTensorDict(crop_length, -1)
        self._batch_size = cfg.batch_size

    def _add(self, td):
        """Store one episode as a single entry (trajectory is the leading dimension)."""
        self._buffer.add(td)

    def sample(self):
        """Draw a batch of cropped subsequences and return the prepared batch."""
        cropped = self._buffer.sample()
        # Swap the two leading dimensions before preparing the batch.
        td = cropped.permute(1,0)
        return self._prepare_batch(td)
class SliceBuffer(Buffer):
"""
A replay buffer that directly samples subsequences. More efficient than CropBuffer.
"""
def __init__(self, cfg):
super().__init__(cfg)
self._capacity = min(cfg.buffer_size, cfg.steps)
self._max_eps = self._capacity//cfg.episode_length
self._sampler = SliceSampler(
num_slices=self.cfg.batch_size,
end_key=None,
traj_key='episode',
truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
def _add(self, td):
"""Add an episode to the buffer, with transitions as the leading dimension."""
self._buffer.extend(td)
def sample(self):
"""Sample a batch of subsequences from the buffer."""
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)

View File

@@ -1,36 +1,18 @@
from __future__ import annotations
import json
import warnings
from abc import ABC, abstractmethod
from copy import copy, deepcopy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from copy import copy
from typing import Tuple
import numpy as np
import torch
from tensordict import MemoryMappedTensor
from tensordict.utils import NestedKey
from torchrl._extension import EXTENSION_WARNING
try:
from torchrl._torchrl import (
MinSegmentTreeFp32,
MinSegmentTreeFp64,
SumSegmentTreeFp32,
SumSegmentTreeFp64,
)
except ImportError:
warnings.warn(EXTENSION_WARNING)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
from torchrl.data.replay_buffers.samplers import Sampler
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):

View File

@@ -1,64 +0,0 @@
import os
os.environ['LAZY_LEGACY_OP'] = '0'
import torch
import hydra
from tensordict.tensordict import TensorDict
from common.buffer import CropBuffer, SliceBuffer
@hydra.main(config_name='config', config_path='.')
def test_buffer(cfg: dict):
    """
    Check that CropBuffer and SliceBuffer yield equivalent batches.

    Two synthetic episodes are added to both buffers. Sampled batches are
    then compared on shape, on temporal ordering of observations, and on
    the long-run mean of the sampled rewards.
    """
    cfg.episode_length = 11
    cfg.batch_size = 8

    # Episode 0: observations grow by 1 per step, actions alternate sign,
    # rewards equal the step index.
    transitions0 = [TensorDict(dict(
        obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t,
        action=torch.tensor([-1.]).unsqueeze(0) ** t,
        reward=torch.tensor([1.]) * t,
    ), batch_size=(1,)) for t in range(cfg.episode_length)]
    episode0 = torch.cat(transitions0)

    # Episode 1: same structure with offset observations, half-magnitude
    # actions and negated rewards.
    transitions1 = [TensorDict(dict(
        obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t,
        action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5,
        reward=torch.tensor([-1.]) * t,
    ), batch_size=(1,)) for t in range(cfg.episode_length)]
    episode1 = torch.cat(transitions1)

    crop_buffer = CropBuffer(cfg)
    slice_buffer = SliceBuffer(cfg)
    for episode in (episode0, episode1):
        crop_buffer.add(episode)
        slice_buffer.add(episode)

    crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
    slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()

    # Both buffers must return identically shaped batches.
    assert crop_obs.shape == slice_obs.shape
    assert crop_action.shape == slice_action.shape
    assert crop_reward.shape == slice_reward.shape

    # Consecutive steps of a sampled subsequence must come from the same
    # episode in order: observations increase by exactly 1 per step.
    assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
    assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()

    # Actions alternate in sign, so the batch mean should be close to zero.
    assert (crop_action.mean().abs() < 0.2)
    assert (slice_action.mean().abs() < 0.2)

    # Compare the long-run mean reward of both sampling schemes.
    crop_rewards, slice_rewards = [], []
    for _ in range(100_000):
        _, _, crop_reward_, _ = crop_buffer.sample()
        _, _, slice_reward_, _ = slice_buffer.sample()
        crop_rewards.append(crop_reward_.mean())
        slice_rewards.append(slice_reward_.mean())
    crop_rewards = torch.tensor(crop_rewards).mean()
    slice_rewards = torch.tensor(slice_rewards).mean()
    # Use the absolute difference: a one-sided comparison would accept an
    # arbitrarily large negative gap between the two means.
    assert (crop_rewards - slice_rewards).abs() < 0.1
if __name__ == '__main__':
test_buffer()

View File

@@ -10,7 +10,7 @@ from termcolor import colored
from common.parser import parse_cfg
from common.seed import set_seed
from common.buffer import CropBuffer, SliceBuffer
from common.buffer import Buffer
from envs import make_env
from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer
@@ -51,7 +51,7 @@ def train(cfg: dict):
cfg=cfg,
env=make_env(cfg),
agent=TDMPC2(cfg),
buffer=SliceBuffer(cfg),
buffer=Buffer(cfg),
logger=Logger(cfg),
)
trainer.train()