compare new/old buffers

This commit is contained in:
Nicklas Hansen
2023-12-22 13:34:12 -08:00
parent fea0936e69
commit 34ea3662cd
4 changed files with 144 additions and 40 deletions

View File

@@ -2,7 +2,6 @@ from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
from common.samplers import SliceSampler
@@ -19,7 +18,6 @@ class Buffer():
self._device = torch.device('cuda')
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
self._capacity = min(cfg.buffer_size, cfg.steps)
self._num_steps = 0
self._num_eps = 0
@property
@@ -27,11 +25,6 @@ class Buffer():
"""Return the capacity of the buffer."""
return self._capacity
@property
def num_steps(self):
"""Return the number of steps in the buffer."""
return self._num_steps
@property
def num_eps(self):
"""Return the number of episodes in the buffer."""
@@ -45,7 +38,8 @@ class Buffer():
storage=storage,
sampler=SliceSampler(
slice_len=self.cfg.horizon+1,
end_key='done',
end_key=None,
traj_key='episode',
truncated_key=None,
),
pin_memory=True,
@@ -53,12 +47,16 @@ class Buffer():
batch_size=self.cfg.batch_size,
)
def _init(self, td):
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()])
print(f'Bytes per step: {bytes_per_step:,}')
total_bytes = bytes_per_step*self._capacity
bytes_per_ep = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
@@ -72,23 +70,20 @@ class Buffer():
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
)
def add(self, td):
def add(self, tds):
"""Add a step to the buffer."""
done = bool(td['done'].any())
if done:
self._num_eps +=1
td['episode'] = torch.ones_like(td['done']) * self._num_eps
td['step'] = torch.arange(0, len(td))
if self._num_steps == 0:
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_steps += 1
return self._num_steps
tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps
tds['step'] = torch.arange(0, len(tds))
if self._num_eps == 0:
self._buffer = self._init(tds)
self._buffer.extend(tds)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
td = self._buffer.sample(batch_size=self._batch_size) \
.reshape(-1, self.cfg.horizon+1).permute(1, 0)
.view(-1, self.cfg.horizon+1).permute(1, 0)
obs = td['obs'].to(self._device, non_blocking=True)
action = td['action'][1:].to(self._device, non_blocking=True)
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)

View File

@@ -0,0 +1,115 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
class DataPrepTransform(Transform):
"""
Preprocesses data for TD-MPC2 training.
Replay data is expected to be a TensorDict with the following keys:
obs: observations
action: actions
reward: rewards
task: task IDs (optional)
A TensorDict with T time steps has T+1 observations and T actions and rewards.
The first actions and rewards in each TensorDict are dummies and should be ignored.
"""
def __init__(self):
super().__init__([])
def forward(self, td):
td = td.permute(1,0)
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
class Buffer():
"""
Create a replay buffer for TD-MPC2 training.
Uses CUDA memory if available, and CPU memory otherwise.
"""
def __init__(self, cfg):
self.cfg = cfg
self._device = torch.device('cuda')
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
self._num_eps = 0
@property
def capacity(self):
"""Return the capacity of the buffer."""
return self._capacity
@property
def num_eps(self):
"""Return the number of episodes in the buffer."""
return self._num_eps
def _reserve_buffer(self, storage):
"""
Reserve a buffer with the given storage.
Uses the RandomSampler to sample trajectories,
and the RandomCropTensorDict transform to crop trajectories to the desired length.
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
"""
return ReplayBuffer(
storage=storage,
sampler=RandomSampler(),
pin_memory=True,
prefetch=1,
transform=Compose(
RandomCropTensorDict(self.cfg.horizon+1, -1),
DataPrepTransform(),
),
batch_size=self.cfg.batch_size,
)
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
print('Using CPU memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
)
else: # Sufficient CUDA memory
print('Using CUDA memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
)
def add(self, tds):
"""Add an episode to the buffer. All episodes are expected to have the same length."""
if self._num_eps == 0:
self._buffer = self._init(tds)
self._buffer.add(tds)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
return obs.to(self._device, non_blocking=True), \
action.to(self._device, non_blocking=True), \
reward.to(self._device, non_blocking=True), \
task.to(self._device, non_blocking=True) if task is not None else None
def save(self):
"""Save the buffer to disk. Useful for storing offline datasets."""
td = self._buffer._storage._storage.cpu()
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
torch.save(td, fp)

View File

@@ -1,5 +1,6 @@
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ['LAZY_LEGACY_OP'] = 0
import warnings
warnings.filterwarnings('ignore')
import torch
@@ -10,6 +11,7 @@ from termcolor import colored
from common.parser import parse_cfg
from common.seed import set_seed
from common.buffer import Buffer
# from common.legacy_buffer import Buffer
from envs import make_env
from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer

View File

@@ -14,7 +14,6 @@ class OnlineTrainer(Trainer):
super().__init__(*args, **kwargs)
self._step = 0
self._ep_idx = 0
self._ep_reward = 0
self._start_time = time()
def common_metrics(self):
@@ -22,7 +21,6 @@ class OnlineTrainer(Trainer):
return dict(
step=self._step,
episode=self._ep_idx,
episode_reward=self._ep_reward,
total_time=time() - self._start_time,
)
@@ -49,24 +47,20 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes),
)
def to_td(self, obs, action=None, reward=None, done=None):
def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode."""
if isinstance(obs, dict):
obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu()
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
else:
obs = obs.cpu()
obs = obs.unsqueeze(0).cpu()
if action is None:
action = torch.empty_like(self.env.rand_act())
if reward is None:
reward = torch.tensor(float('nan'))
if done is None:
done = False
done = torch.tensor(done)
td = TensorDict(dict(
obs=obs.unsqueeze(0),
obs=obs,
action=action.unsqueeze(0),
reward=reward.unsqueeze(0),
done=done.unsqueeze(0),
), batch_size=(1,))
return td
@@ -89,16 +83,15 @@ class OnlineTrainer(Trainer):
if self._step > 0:
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx += 1
self.buffer.add(torch.cat(self._tds))
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset()
self._tds = [self.to_td(obs)]
self._ep_reward = 0
# Collect experience
if self._step > self.cfg.seed_steps:
@@ -106,8 +99,7 @@ class OnlineTrainer(Trainer):
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward, done))
self._ep_reward += reward
self._tds.append(self.to_td(obs, action, reward))
# Update agent
if self._step >= self.cfg.seed_steps: