compare new/old buffers

This commit is contained in:
Nicklas Hansen
2023-12-22 13:34:12 -08:00
parent fea0936e69
commit 34ea3662cd
4 changed files with 144 additions and 40 deletions

View File

@@ -2,7 +2,6 @@ from pathlib import Path
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir from common.logger import make_dir
from common.samplers import SliceSampler from common.samplers import SliceSampler
@@ -19,7 +18,6 @@ class Buffer():
self._device = torch.device('cuda') self._device = torch.device('cuda')
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._num_steps = 0
self._num_eps = 0 self._num_eps = 0
@property @property
@@ -27,11 +25,6 @@ class Buffer():
"""Return the capacity of the buffer.""" """Return the capacity of the buffer."""
return self._capacity return self._capacity
@property
def num_steps(self):
"""Return the number of steps in the buffer."""
return self._num_steps
@property @property
def num_eps(self): def num_eps(self):
"""Return the number of episodes in the buffer.""" """Return the number of episodes in the buffer."""
@@ -45,7 +38,8 @@ class Buffer():
storage=storage, storage=storage,
sampler=SliceSampler( sampler=SliceSampler(
slice_len=self.cfg.horizon+1, slice_len=self.cfg.horizon+1,
end_key='done', end_key=None,
traj_key='episode',
truncated_key=None, truncated_key=None,
), ),
pin_memory=True, pin_memory=True,
@@ -53,12 +47,16 @@ class Buffer():
batch_size=self.cfg.batch_size, batch_size=self.cfg.batch_size,
) )
def _init(self, td): def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements.""" """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
mem_free, _ = torch.cuda.mem_get_info() mem_free, _ = torch.cuda.mem_get_info()
bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()]) bytes_per_ep = sum([
print(f'Bytes per step: {bytes_per_step:,}') (v.numel()*v.element_size() if not isinstance(v, TensorDict) \
total_bytes = bytes_per_step*self._capacity else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
@@ -72,23 +70,20 @@ class Buffer():
LazyTensorStorage(self._capacity, device=torch.device('cuda')) LazyTensorStorage(self._capacity, device=torch.device('cuda'))
) )
def add(self, td): def add(self, tds):
"""Add a step to the buffer.""" """Add a step to the buffer."""
done = bool(td['done'].any()) tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps
if done: tds['step'] = torch.arange(0, len(tds))
self._num_eps +=1 if self._num_eps == 0:
td['episode'] = torch.ones_like(td['done']) * self._num_eps self._buffer = self._init(tds)
td['step'] = torch.arange(0, len(td)) self._buffer.extend(tds)
if self._num_steps == 0: self._num_eps += 1
self._buffer = self._init(td) return self._num_eps
self._buffer.extend(td)
self._num_steps += 1
return self._num_steps
def sample(self): def sample(self):
"""Sample a batch of sub-trajectories from the buffer.""" """Sample a batch of sub-trajectories from the buffer."""
td = self._buffer.sample(batch_size=self._batch_size) \ td = self._buffer.sample(batch_size=self._batch_size) \
.reshape(-1, self.cfg.horizon+1).permute(1, 0) .view(-1, self.cfg.horizon+1).permute(1, 0)
obs = td['obs'].to(self._device, non_blocking=True) obs = td['obs'].to(self._device, non_blocking=True)
action = td['action'][1:].to(self._device, non_blocking=True) action = td['action'][1:].to(self._device, non_blocking=True)
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True) reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)

View File

@@ -0,0 +1,115 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
class DataPrepTransform(Transform):
    """
    Transform that unpacks sampled replay data into the tuple used by TD-MPC2 updates.

    The incoming TensorDict is expected to provide these keys:
        obs: observations
        action: actions
        reward: rewards
        task: task IDs (optional)
    A TensorDict with T time steps has T+1 observations and T actions and rewards.
    The first action and reward of each TensorDict are dummies, so they are dropped here.
    """

    def __init__(self):
        # No in_keys: this transform consumes the whole TensorDict in `forward`.
        super().__init__([])

    def forward(self, td):
        """Return `(obs, action, reward, task)`; `task` is None when the key is absent."""
        # Swap the two batch dimensions.
        # NOTE(review): presumably (batch, time) -> (time, batch) — confirm against the sampler.
        swapped = td.permute(1, 0)
        obs = swapped['obs']
        # Index 0 along the leading dimension holds the dummy action/reward.
        action = swapped['action'][1:]
        reward = swapped['reward'][1:].unsqueeze(-1)
        if 'task' in swapped.keys():
            task = swapped['task'][0]
        else:
            task = None
        return obs, action, reward, task
class Buffer():
    """
    Create a replay buffer for TD-MPC2 training.

    Episodes are stored whole. Storage is placed in CUDA memory when enough of it
    is free (see `_init`), and in CPU memory otherwise. Sampled batches are always
    moved to the CUDA device before being returned.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self._device = torch.device('cuda')
        # Capacity is measured in episodes, hence the division by episode length.
        self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
        self._num_eps = 0
        # The underlying replay buffer is allocated lazily on the first `add`.
        self._buffer = None

    @property
    def capacity(self):
        """Return the capacity of the buffer."""
        return self._capacity

    @property
    def num_eps(self):
        """Return the number of episodes added to the buffer so far."""
        return self._num_eps

    def _require_buffer(self):
        """Return the underlying replay buffer, raising RuntimeError if none exists yet."""
        if self._buffer is None:
            raise RuntimeError('Buffer is empty; add at least one episode before using it.')
        return self._buffer

    @staticmethod
    def _nbytes(value):
        """Return the bytes held by a tensor, or recursively by a (nested) TensorDict."""
        if isinstance(value, TensorDict):
            return sum(Buffer._nbytes(v) for v in value.values())
        return value.numel()*value.element_size()

    def _reserve_buffer(self, storage):
        """
        Reserve a buffer with the given storage.
        Uses the RandomSampler to sample trajectories,
        and the RandomCropTensorDict transform to crop trajectories to the desired length.
        DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
        """
        return ReplayBuffer(
            storage=storage,
            sampler=RandomSampler(),
            pin_memory=True,
            prefetch=1,
            transform=Compose(
                RandomCropTensorDict(self.cfg.horizon+1, -1),
                DataPrepTransform(),
            ),
            batch_size=self.cfg.batch_size,
        )

    def _init(self, tds):
        """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
        mem_free, _ = torch.cuda.mem_get_info()
        bytes_per_ep = sum(self._nbytes(v) for v in tds.values())
        print(f'Bytes per episode: {bytes_per_ep:,}')
        total_bytes = bytes_per_ep*self._capacity
        print(f'Storage required: {total_bytes/1e9:.2f} GB')
        # Heuristic: only keep storage on CUDA when 2.5x the estimated size fits in free memory.
        use_cuda = 2.5*total_bytes <= mem_free
        print(f'Using {"CUDA" if use_cuda else "CPU"} memory for storage.')
        device = torch.device('cuda' if use_cuda else 'cpu')
        return self._reserve_buffer(
            LazyTensorStorage(self._capacity, device=device)
        )

    def add(self, tds):
        """
        Add an episode to the buffer. All episodes are expected to have the same length.
        Returns the total number of episodes added so far.
        """
        if self._buffer is None:
            self._buffer = self._init(tds)
        self._buffer.add(tds)
        self._num_eps += 1
        return self._num_eps

    def sample(self):
        """
        Sample a batch of sub-trajectories from the buffer.
        Returns `(obs, action, reward, task)` on the CUDA device; `task` may be None.
        Raises RuntimeError if no episode has been added yet.
        """
        obs, action, reward, task = self._require_buffer().sample(batch_size=self.cfg.batch_size)
        return obs.to(self._device, non_blocking=True), \
            action.to(self._device, non_blocking=True), \
            reward.to(self._device, non_blocking=True), \
            task.to(self._device, non_blocking=True) if task is not None else None

    def save(self):
        """
        Save the buffer to disk. Useful for storing offline datasets.
        Raises RuntimeError if no episode has been added yet.
        """
        # Reaches into the storage's private tensor container to dump the raw data.
        td = self._require_buffer()._storage._storage.cpu()
        fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
        torch.save(td, fp)

View File

@@ -1,5 +1,6 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = 'egl'
# os.environ only accepts string values; assigning the int 0 raises TypeError at import time.
os.environ['LAZY_LEGACY_OP'] = '0'
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import torch import torch
@@ -10,6 +11,7 @@ from termcolor import colored
from common.parser import parse_cfg from common.parser import parse_cfg
from common.seed import set_seed from common.seed import set_seed
from common.buffer import Buffer from common.buffer import Buffer
# from common.legacy_buffer import Buffer
from envs import make_env from envs import make_env
from tdmpc2 import TDMPC2 from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer

View File

@@ -14,7 +14,6 @@ class OnlineTrainer(Trainer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._step = 0 self._step = 0
self._ep_idx = 0 self._ep_idx = 0
self._ep_reward = 0
self._start_time = time() self._start_time = time()
def common_metrics(self): def common_metrics(self):
@@ -22,7 +21,6 @@ class OnlineTrainer(Trainer):
return dict( return dict(
step=self._step, step=self._step,
episode=self._ep_idx, episode=self._ep_idx,
episode_reward=self._ep_reward,
total_time=time() - self._start_time, total_time=time() - self._start_time,
) )
@@ -49,24 +47,20 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
) )
def to_td(self, obs, action=None, reward=None, done=None): def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu() obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
else: else:
obs = obs.cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action = torch.empty_like(self.env.rand_act()) action = torch.empty_like(self.env.rand_act())
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
if done is None:
done = False
done = torch.tensor(done)
td = TensorDict(dict( td = TensorDict(dict(
obs=obs.unsqueeze(0), obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
done=done.unsqueeze(0),
), batch_size=(1,)) ), batch_size=(1,))
return td return td
@@ -89,16 +83,15 @@ class OnlineTrainer(Trainer):
if self._step > 0: if self._step > 0:
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'], episode_success=info['success'],
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx += 1 self._ep_idx = self.buffer.add(torch.cat(self._tds))
self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
self._ep_reward = 0
# Collect experience # Collect experience
if self._step > self.cfg.seed_steps: if self._step > self.cfg.seed_steps:
@@ -106,8 +99,7 @@ class OnlineTrainer(Trainer):
else: else:
action = self.env.rand_act() action = self.env.rand_act()
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward, done)) self._tds.append(self.to_td(obs, action, reward))
self._ep_reward += reward
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps: