From 54145a4d8c4c080836ff1f186fc5a87f70c8a8c7 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 27 Dec 2023 08:49:04 -0800 Subject: [PATCH] integrate slicesampler as default --- tdmpc2/common/buffer.py | 84 +++++++-------------------------------- tdmpc2/common/samplers.py | 28 +++---------- tdmpc2/test_buffer.py | 64 ----------------------------- tdmpc2/train.py | 4 +- 4 files changed, 22 insertions(+), 158 deletions(-) delete mode 100644 tdmpc2/test_buffer.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 75d17d2..29cc293 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -1,27 +1,28 @@ import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.envs import RandomCropTensorDict from common.samplers import SliceSampler class Buffer(): """ - Base class for TD-MPC2 replay buffers. + Replay buffer for TD-MPC2 training. Based on torchrl. Uses CUDA memory if available, and CPU memory otherwise. """ def __init__(self, cfg): self.cfg = cfg self._device = torch.device('cuda') - self._capacity = None - self._max_eps = None + self._capacity = min(cfg.buffer_size, cfg.steps) + self._sampler = SliceSampler( + num_slices=self.cfg.batch_size, + end_key=None, + traj_key='episode', + truncated_key=None, + ) + self._batch_size = cfg.batch_size * (cfg.horizon+1) self._num_eps = 0 - self._sampler = None - self._transform = None - self._batch_size = None @property def capacity(self): @@ -42,21 +43,19 @@ class Buffer(): sampler=self._sampler, pin_memory=True, prefetch=1, - transform=self._transform, batch_size=self._batch_size, ) def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" - print('Buffer capacity:', self._capacity) + print(f'Buffer capacity: {self._capacity:,}') mem_free, _ = torch.cuda.mem_get_info() - bytes_per_ep = sum([ + bytes_per_step = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) - print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._max_eps + for v in tds.values() + ]) / len(tds) + total_bytes = bytes_per_step*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' @@ -82,68 +81,15 @@ class Buffer(): task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, task) - def _add(self, td): - """Internal function that adds episode to the buffer.""" - pass - def add(self, td): """Add an episode to the buffer.""" td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps if self._num_eps == 0: self._buffer = self._init(td) - self._add(td) + self._buffer.extend(td) self._num_eps += 1 return self._num_eps - def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - pass - - -class CropBuffer(Buffer): - """ - A replay buffer that first samples trajectories, and then crops to desired length. - """ - - def __init__(self, cfg): - super().__init__(cfg) - self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length - self._max_eps = self._capacity - self._sampler = RandomSampler() - self._transform = RandomCropTensorDict(cfg.horizon+1, -1) - self._batch_size = cfg.batch_size - - def _add(self, td): - """Add an episode to the buffer, with trajectories as the leading dimension.""" - self._buffer.add(td) - - def sample(self): - """Sample a batch of subsequences from the buffer.""" - td = self._buffer.sample().permute(1,0) - return self._prepare_batch(td) - - -class SliceBuffer(Buffer): - """ - A replay buffer that directly samples subsequences. More efficient than CropBuffer. - """ - - def __init__(self, cfg): - super().__init__(cfg) - self._capacity = min(cfg.buffer_size, cfg.steps) - self._max_eps = self._capacity//cfg.episode_length - self._sampler = SliceSampler( - num_slices=self.cfg.batch_size, - end_key=None, - traj_key='episode', - truncated_key=None, - ) - self._batch_size = cfg.batch_size * (cfg.horizon+1) - - def _add(self, td): - """Add an episode to the buffer, with transitions as the leading dimension.""" - self._buffer.extend(td) - def sample(self): """Sample a batch of subsequences from the buffer.""" td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0) diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py index a8cd72f..71cf70c 100644 --- a/tdmpc2/common/samplers.py +++ b/tdmpc2/common/samplers.py @@ -1,36 +1,18 @@ from __future__ import annotations -import json -import warnings -from abc import ABC, abstractmethod -from copy import copy, deepcopy -from multiprocessing.context import get_spawning_popen -from pathlib import Path -from typing import Any, Dict, Tuple, Union +from copy import copy +from typing import Tuple -import numpy as np import torch -from tensordict import MemoryMappedTensor from tensordict.utils import NestedKey -from torchrl._extension import EXTENSION_WARNING - -try: - from torchrl._torchrl import ( - MinSegmentTreeFp32, - MinSegmentTreeFp64, - SumSegmentTreeFp32, - SumSegmentTreeFp64, - ) -except ImportError: - warnings.warn(EXTENSION_WARNING) - from torchrl.data.replay_buffers.storages import Storage, TensorStorage -from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES from torchrl.data.replay_buffers.samplers import Sampler -_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage." + +# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html +# This copy will live here until it has been included in a few torchrl stable releases class SliceSampler(Sampler): diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py deleted file mode 100644 index 87e9212..0000000 --- a/tdmpc2/test_buffer.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -os.environ['LAZY_LEGACY_OP'] = '0' - -import torch -import hydra -from tensordict.tensordict import TensorDict - -from common.buffer import CropBuffer, SliceBuffer - - -@hydra.main(config_name='config', config_path='.') -def test_buffer(cfg: dict): - cfg.episode_length = 11 - cfg.batch_size = 8 - - transitions0 = [TensorDict(dict( - obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, - action=torch.tensor([-1.]).unsqueeze(0) ** t, - reward=torch.tensor([1.]) * t, - ), batch_size=(1,)) for t in range(cfg.episode_length)] - episode0 = torch.cat(transitions0) - - transitions1 = [TensorDict(dict( - obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, - action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5, - reward=torch.tensor([-1.]) * t, - ), batch_size=(1,)) for t in range(cfg.episode_length)] - episode1 = torch.cat(transitions1) - - crop_buffer = CropBuffer(cfg) - slice_buffer = SliceBuffer(cfg) - - crop_buffer.add(episode0) - slice_buffer.add(episode0) - - crop_buffer.add(episode1) - slice_buffer.add(episode1) - - crop_obs, crop_action, crop_reward, _ = crop_buffer.sample() - slice_obs, slice_action, slice_reward, _ = slice_buffer.sample() - - assert crop_obs.shape == slice_obs.shape - assert crop_action.shape == slice_action.shape - assert crop_reward.shape == slice_reward.shape - - assert (crop_obs[1:] - crop_obs[:-1] == 1.).all() - assert (slice_obs[1:] - slice_obs[:-1] == 1.).all() - assert (crop_action.mean().abs() < 0.2) - assert (slice_action.mean().abs() < 0.2) - - crop_rewards, slice_rewards = [], [] - for _ in range(100_000): - _, _, crop_reward_, _ = crop_buffer.sample() - _, _, slice_reward_, _ = slice_buffer.sample() - crop_rewards.append(crop_reward_.mean()) - slice_rewards.append(slice_reward_.mean()) - - crop_rewards = torch.tensor(crop_rewards).mean() - slice_rewards = torch.tensor(slice_rewards).mean() - assert (crop_rewards - slice_rewards) < 0.1 - - -if __name__ == '__main__': - test_buffer() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index b091bec..5953bb2 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -10,7 +10,7 @@ from termcolor import colored from common.parser import parse_cfg from common.seed import set_seed -from common.buffer import CropBuffer, SliceBuffer +from common.buffer import Buffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer @@ -51,7 +51,7 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=SliceBuffer(cfg), + buffer=Buffer(cfg), logger=Logger(cfg), ) trainer.train()