integrate slicesampler as default

This commit is contained in:
Nicklas Hansen
2023-12-27 08:49:04 -08:00
parent 2f86a1e4d8
commit 54145a4d8c
4 changed files with 22 additions and 158 deletions

View File

@@ -1,27 +1,28 @@
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict
from common.samplers import SliceSampler from common.samplers import SliceSampler
class Buffer(): class Buffer():
""" """
Base class for TD-MPC2 replay buffers. Replay buffer for TD-MPC2 training. Based on torchrl.
Uses CUDA memory if available, and CPU memory otherwise. Uses CUDA memory if available, and CPU memory otherwise.
""" """
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device('cuda')
self._capacity = None self._capacity = min(cfg.buffer_size, cfg.steps)
self._max_eps = None self._sampler = SliceSampler(
num_slices=self.cfg.batch_size,
end_key=None,
traj_key='episode',
truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0 self._num_eps = 0
self._sampler = None
self._transform = None
self._batch_size = None
@property @property
def capacity(self): def capacity(self):
@@ -42,21 +43,19 @@ class Buffer():
sampler=self._sampler, sampler=self._sampler,
pin_memory=True, pin_memory=True,
prefetch=1, prefetch=1,
transform=self._transform,
batch_size=self._batch_size, batch_size=self._batch_size,
) )
def _init(self, tds): def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements.""" """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
print('Buffer capacity:', self._capacity) print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info() mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([ bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \ else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items() for v in tds.values()
]) ]) / len(tds)
print(f'Bytes per episode: {bytes_per_ep:,}') total_bytes = bytes_per_step*self._capacity
total_bytes = bytes_per_ep*self._max_eps
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
@@ -82,68 +81,15 @@ class Buffer():
task = td['task'][0] if 'task' in td.keys() else None task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task) return self._to_device(obs, action, reward, task)
def _add(self, td):
"""Internal function that adds episode to the buffer."""
pass
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td)
self._add(td) self._buffer.extend(td)
self._num_eps += 1 self._num_eps += 1
return self._num_eps return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
pass
class CropBuffer(Buffer):
	"""
	Replay buffer that stores whole episodes, samples them at random,
	and crops each sampled episode to a window of `cfg.horizon+1` steps.
	"""

	def __init__(self, cfg):
		super().__init__(cfg)
		max_transitions = min(cfg.buffer_size, cfg.steps)
		# Capacity is counted in episodes for this buffer, so the same
		# number also serves as the maximum episode count.
		self._max_eps = self._capacity = max_transitions // cfg.episode_length
		self._batch_size = cfg.batch_size
		self._sampler = RandomSampler()
		# NOTE(review): the -1 is presumably torchrl's `sample_dim`
		# (crop along the last batch dimension) -- confirm against torchrl docs.
		self._transform = RandomCropTensorDict(cfg.horizon + 1, -1)

	def _add(self, td):
		"""Store one episode as a single entry of the underlying buffer."""
		self._buffer.add(td)

	def sample(self):
		"""Draw a batch from the buffer, swap its two leading dims, and prepare it."""
		batch = self._buffer.sample()
		return self._prepare_batch(batch.permute(1, 0))
class SliceBuffer(Buffer):
"""
A replay buffer that directly samples subsequences. More efficient than CropBuffer.
"""
def __init__(self, cfg):
super().__init__(cfg)
self._capacity = min(cfg.buffer_size, cfg.steps)
self._max_eps = self._capacity//cfg.episode_length
self._sampler = SliceSampler(
num_slices=self.cfg.batch_size,
end_key=None,
traj_key='episode',
truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
def _add(self, td):
"""Add an episode to the buffer, with transitions as the leading dimension."""
self._buffer.extend(td)
def sample(self): def sample(self):
"""Sample a batch of subsequences from the buffer.""" """Sample a batch of subsequences from the buffer."""
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0) td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)

View File

@@ -1,36 +1,18 @@
from __future__ import annotations from __future__ import annotations
import json from copy import copy
import warnings from typing import Tuple
from abc import ABC, abstractmethod
from copy import copy, deepcopy
from multiprocessing.context import get_spawning_popen
from pathlib import Path
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch import torch
from tensordict import MemoryMappedTensor
from tensordict.utils import NestedKey from tensordict.utils import NestedKey
from torchrl._extension import EXTENSION_WARNING
try:
from torchrl._torchrl import (
MinSegmentTreeFp32,
MinSegmentTreeFp64,
SumSegmentTreeFp32,
SumSegmentTreeFp64,
)
except ImportError:
warnings.warn(EXTENSION_WARNING)
from torchrl.data.replay_buffers.storages import Storage, TensorStorage from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES
from torchrl.data.replay_buffers.samplers import Sampler from torchrl.data.replay_buffers.samplers import Sampler
_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage."
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler): class SliceSampler(Sampler):

View File

@@ -1,64 +0,0 @@
import os
os.environ['LAZY_LEGACY_OP'] = '0'
import torch
import hydra
from tensordict.tensordict import TensorDict
from common.buffer import CropBuffer, SliceBuffer
def _make_episode(length, obs_start, action_scale, reward_sign):
	"""
	Build a deterministic synthetic episode of `length` transitions.

	At step t the observation is [obs_start, ..., obs_start+4] + t, the action
	is (-1)**t * action_scale, and the reward is reward_sign * t.
	"""
	base_obs = (torch.arange(5, dtype=torch.float32) + obs_start).unsqueeze(0)
	transitions = [
		TensorDict(dict(
			obs=base_obs + t,
			action=(torch.tensor([-1.]) ** t).unsqueeze(0) * action_scale,
			reward=torch.tensor([reward_sign]) * t,
		), batch_size=(1,))
		for t in range(length)
	]
	return torch.cat(transitions)


@hydra.main(config_name='config', config_path='.')
def test_buffer(cfg: dict):
	"""
	Check that CropBuffer and SliceBuffer produce equivalent samples.

	Both buffers receive the same two synthetic episodes. The test then checks
	that sampled batches have matching shapes, that observations advance by
	exactly 1 along the leading dimension, that sign-alternating actions
	average out near zero, and that the mean sampled reward of the two buffers
	agrees over many draws.
	"""
	cfg.episode_length = 11
	cfg.batch_size = 8

	episode0 = _make_episode(cfg.episode_length, obs_start=0., action_scale=1., reward_sign=1.)
	episode1 = _make_episode(cfg.episode_length, obs_start=20., action_scale=0.5, reward_sign=-1.)

	crop_buffer = CropBuffer(cfg)
	slice_buffer = SliceBuffer(cfg)
	for episode in (episode0, episode1):
		crop_buffer.add(episode)
		slice_buffer.add(episode)

	crop_obs, crop_action, crop_reward, _ = crop_buffer.sample()
	slice_obs, slice_action, slice_reward, _ = slice_buffer.sample()

	# Both samplers must return batches with the same layout.
	assert crop_obs.shape == slice_obs.shape
	assert crop_action.shape == slice_action.shape
	assert crop_reward.shape == slice_reward.shape

	# Observations grow by 1 per step, so neighbouring entries along the
	# leading dimension must differ by exactly 1 if subsequences are contiguous.
	assert (crop_obs[1:] - crop_obs[:-1] == 1.).all()
	assert (slice_obs[1:] - slice_obs[:-1] == 1.).all()

	# Actions alternate in sign from step to step, so their mean stays near zero.
	assert (crop_action.mean().abs() < 0.2)
	assert (slice_action.mean().abs() < 0.2)

	num_draws = 100_000
	crop_rewards, slice_rewards = [], []
	for _ in range(num_draws):
		_, _, crop_reward_, _ = crop_buffer.sample()
		_, _, slice_reward_, _ = slice_buffer.sample()
		crop_rewards.append(crop_reward_.mean())
		slice_rewards.append(slice_reward_.mean())
	crop_rewards = torch.stack(crop_rewards).mean()
	slice_rewards = torch.stack(slice_rewards).mean()

	# Compare the magnitude of the gap: a signed difference would pass
	# whenever the crop mean is lower, no matter how far apart the two are.
	assert (crop_rewards - slice_rewards).abs() < 0.1
# Script entry point: run the buffer comparison when executed directly
# (hydra supplies `cfg` through the decorator on `test_buffer`).
if __name__ == '__main__':
	test_buffer()

View File

@@ -10,7 +10,7 @@ from termcolor import colored
from common.parser import parse_cfg from common.parser import parse_cfg
from common.seed import set_seed from common.seed import set_seed
from common.buffer import CropBuffer, SliceBuffer from common.buffer import Buffer
from envs import make_env from envs import make_env
from tdmpc2 import TDMPC2 from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer
@@ -51,7 +51,7 @@ def train(cfg: dict):
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),
agent=TDMPC2(cfg), agent=TDMPC2(cfg),
buffer=SliceBuffer(cfg), buffer=Buffer(cfg),
logger=Logger(cfg), logger=Logger(cfg),
) )
trainer.train() trainer.train()