compare new/old buffers
This commit is contained in:
@@ -2,7 +2,6 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||||
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
|
||||||
|
|
||||||
from common.logger import make_dir
|
from common.logger import make_dir
|
||||||
from common.samplers import SliceSampler
|
from common.samplers import SliceSampler
|
||||||
@@ -19,7 +18,6 @@ class Buffer():
|
|||||||
self._device = torch.device('cuda')
|
self._device = torch.device('cuda')
|
||||||
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
|
self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1)
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
self._num_steps = 0
|
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -27,11 +25,6 @@ class Buffer():
|
|||||||
"""Return the capacity of the buffer."""
|
"""Return the capacity of the buffer."""
|
||||||
return self._capacity
|
return self._capacity
|
||||||
|
|
||||||
@property
|
|
||||||
def num_steps(self):
|
|
||||||
"""Return the number of steps in the buffer."""
|
|
||||||
return self._num_steps
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_eps(self):
|
def num_eps(self):
|
||||||
"""Return the number of episodes in the buffer."""
|
"""Return the number of episodes in the buffer."""
|
||||||
@@ -45,7 +38,8 @@ class Buffer():
|
|||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=SliceSampler(
|
sampler=SliceSampler(
|
||||||
slice_len=self.cfg.horizon+1,
|
slice_len=self.cfg.horizon+1,
|
||||||
end_key='done',
|
end_key=None,
|
||||||
|
traj_key='episode',
|
||||||
truncated_key=None,
|
truncated_key=None,
|
||||||
),
|
),
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
@@ -53,12 +47,16 @@ class Buffer():
|
|||||||
batch_size=self.cfg.batch_size,
|
batch_size=self.cfg.batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init(self, td):
|
def _init(self, tds):
|
||||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||||
mem_free, _ = torch.cuda.mem_get_info()
|
mem_free, _ = torch.cuda.mem_get_info()
|
||||||
bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()])
|
bytes_per_ep = sum([
|
||||||
print(f'Bytes per step: {bytes_per_step:,}')
|
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||||
total_bytes = bytes_per_step*self._capacity
|
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
||||||
|
for k,v in tds.items()
|
||||||
|
])
|
||||||
|
print(f'Bytes per episode: {bytes_per_ep:,}')
|
||||||
|
total_bytes = bytes_per_ep*self._capacity
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
||||||
@@ -72,23 +70,20 @@ class Buffer():
|
|||||||
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
||||||
)
|
)
|
||||||
|
|
||||||
def add(self, td):
|
def add(self, tds):
|
||||||
"""Add a step to the buffer."""
|
"""Add a step to the buffer."""
|
||||||
done = bool(td['done'].any())
|
tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps
|
||||||
if done:
|
tds['step'] = torch.arange(0, len(tds))
|
||||||
|
if self._num_eps == 0:
|
||||||
|
self._buffer = self._init(tds)
|
||||||
|
self._buffer.extend(tds)
|
||||||
self._num_eps += 1
|
self._num_eps += 1
|
||||||
td['episode'] = torch.ones_like(td['done']) * self._num_eps
|
return self._num_eps
|
||||||
td['step'] = torch.arange(0, len(td))
|
|
||||||
if self._num_steps == 0:
|
|
||||||
self._buffer = self._init(td)
|
|
||||||
self._buffer.extend(td)
|
|
||||||
self._num_steps += 1
|
|
||||||
return self._num_steps
|
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of sub-trajectories from the buffer."""
|
"""Sample a batch of sub-trajectories from the buffer."""
|
||||||
td = self._buffer.sample(batch_size=self._batch_size) \
|
td = self._buffer.sample(batch_size=self._batch_size) \
|
||||||
.reshape(-1, self.cfg.horizon+1).permute(1, 0)
|
.view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
obs = td['obs'].to(self._device, non_blocking=True)
|
obs = td['obs'].to(self._device, non_blocking=True)
|
||||||
action = td['action'][1:].to(self._device, non_blocking=True)
|
action = td['action'][1:].to(self._device, non_blocking=True)
|
||||||
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
|
reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True)
|
||||||
|
|||||||
115
tdmpc2/common/legacy_buffer.py
Normal file
115
tdmpc2/common/legacy_buffer.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import torch
|
||||||
|
from tensordict.tensordict import TensorDict
|
||||||
|
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||||
|
from torchrl.data.replay_buffers.samplers import RandomSampler
|
||||||
|
from torchrl.envs import RandomCropTensorDict, Transform, Compose
|
||||||
|
|
||||||
|
from common.logger import make_dir
|
||||||
|
|
||||||
|
|
||||||
|
class DataPrepTransform(Transform):
|
||||||
|
"""
|
||||||
|
Preprocesses data for TD-MPC2 training.
|
||||||
|
Replay data is expected to be a TensorDict with the following keys:
|
||||||
|
obs: observations
|
||||||
|
action: actions
|
||||||
|
reward: rewards
|
||||||
|
task: task IDs (optional)
|
||||||
|
A TensorDict with T time steps has T+1 observations and T actions and rewards.
|
||||||
|
The first actions and rewards in each TensorDict are dummies and should be ignored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__([])
|
||||||
|
|
||||||
|
def forward(self, td):
|
||||||
|
td = td.permute(1,0)
|
||||||
|
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
|
||||||
|
|
||||||
|
|
||||||
|
class Buffer():
|
||||||
|
"""
|
||||||
|
Create a replay buffer for TD-MPC2 training.
|
||||||
|
Uses CUDA memory if available, and CPU memory otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
self.cfg = cfg
|
||||||
|
self._device = torch.device('cuda')
|
||||||
|
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
|
||||||
|
self._num_eps = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capacity(self):
|
||||||
|
"""Return the capacity of the buffer."""
|
||||||
|
return self._capacity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_eps(self):
|
||||||
|
"""Return the number of episodes in the buffer."""
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
|
def _reserve_buffer(self, storage):
|
||||||
|
"""
|
||||||
|
Reserve a buffer with the given storage.
|
||||||
|
Uses the RandomSampler to sample trajectories,
|
||||||
|
and the RandomCropTensorDict transform to crop trajectories to the desired length.
|
||||||
|
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
|
||||||
|
"""
|
||||||
|
return ReplayBuffer(
|
||||||
|
storage=storage,
|
||||||
|
sampler=RandomSampler(),
|
||||||
|
pin_memory=True,
|
||||||
|
prefetch=1,
|
||||||
|
transform=Compose(
|
||||||
|
RandomCropTensorDict(self.cfg.horizon+1, -1),
|
||||||
|
DataPrepTransform(),
|
||||||
|
),
|
||||||
|
batch_size=self.cfg.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init(self, tds):
|
||||||
|
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||||
|
mem_free, _ = torch.cuda.mem_get_info()
|
||||||
|
bytes_per_ep = sum([
|
||||||
|
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||||
|
else sum([x.numel()*x.element_size() for x in v.values()])) \
|
||||||
|
for k,v in tds.items()
|
||||||
|
])
|
||||||
|
print(f'Bytes per episode: {bytes_per_ep:,}')
|
||||||
|
total_bytes = bytes_per_ep*self._capacity
|
||||||
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
|
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
|
||||||
|
print('Using CPU memory for storage.')
|
||||||
|
return self._reserve_buffer(
|
||||||
|
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
|
||||||
|
)
|
||||||
|
else: # Sufficient CUDA memory
|
||||||
|
print('Using CUDA memory for storage.')
|
||||||
|
return self._reserve_buffer(
|
||||||
|
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
|
||||||
|
)
|
||||||
|
|
||||||
|
def add(self, tds):
|
||||||
|
"""Add an episode to the buffer. All episodes are expected to have the same length."""
|
||||||
|
if self._num_eps == 0:
|
||||||
|
self._buffer = self._init(tds)
|
||||||
|
self._buffer.add(tds)
|
||||||
|
self._num_eps += 1
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
"""Sample a batch of sub-trajectories from the buffer."""
|
||||||
|
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
|
||||||
|
return obs.to(self._device, non_blocking=True), \
|
||||||
|
action.to(self._device, non_blocking=True), \
|
||||||
|
reward.to(self._device, non_blocking=True), \
|
||||||
|
task.to(self._device, non_blocking=True) if task is not None else None
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""Save the buffer to disk. Useful for storing offline datasets."""
|
||||||
|
td = self._buffer._storage._storage.cpu()
|
||||||
|
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
|
||||||
|
torch.save(td, fp)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
os.environ['MUJOCO_GL'] = 'egl'
|
os.environ['MUJOCO_GL'] = 'egl'
|
||||||
|
os.environ['LAZY_LEGACY_OP'] = 0
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
import torch
|
import torch
|
||||||
@@ -10,6 +11,7 @@ from termcolor import colored
|
|||||||
from common.parser import parse_cfg
|
from common.parser import parse_cfg
|
||||||
from common.seed import set_seed
|
from common.seed import set_seed
|
||||||
from common.buffer import Buffer
|
from common.buffer import Buffer
|
||||||
|
# from common.legacy_buffer import Buffer
|
||||||
from envs import make_env
|
from envs import make_env
|
||||||
from tdmpc2 import TDMPC2
|
from tdmpc2 import TDMPC2
|
||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ class OnlineTrainer(Trainer):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._step = 0
|
self._step = 0
|
||||||
self._ep_idx = 0
|
self._ep_idx = 0
|
||||||
self._ep_reward = 0
|
|
||||||
self._start_time = time()
|
self._start_time = time()
|
||||||
|
|
||||||
def common_metrics(self):
|
def common_metrics(self):
|
||||||
@@ -22,7 +21,6 @@ class OnlineTrainer(Trainer):
|
|||||||
return dict(
|
return dict(
|
||||||
step=self._step,
|
step=self._step,
|
||||||
episode=self._ep_idx,
|
episode=self._ep_idx,
|
||||||
episode_reward=self._ep_reward,
|
|
||||||
total_time=time() - self._start_time,
|
total_time=time() - self._start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,24 +47,20 @@ class OnlineTrainer(Trainer):
|
|||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None, done=None):
|
def to_td(self, obs, action=None, reward=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu()
|
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
|
||||||
else:
|
else:
|
||||||
obs = obs.cpu()
|
obs = obs.unsqueeze(0).cpu()
|
||||||
if action is None:
|
if action is None:
|
||||||
action = torch.empty_like(self.env.rand_act())
|
action = torch.empty_like(self.env.rand_act())
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
if done is None:
|
|
||||||
done = False
|
|
||||||
done = torch.tensor(done)
|
|
||||||
td = TensorDict(dict(
|
td = TensorDict(dict(
|
||||||
obs=obs.unsqueeze(0),
|
obs=obs,
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
done=done.unsqueeze(0),
|
|
||||||
), batch_size=(1,))
|
), batch_size=(1,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
@@ -89,16 +83,15 @@ class OnlineTrainer(Trainer):
|
|||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._ep_idx += 1
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
self.buffer.add(torch.cat(self._tds))
|
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
self._tds = [self.to_td(obs)]
|
self._tds = [self.to_td(obs)]
|
||||||
self._ep_reward = 0
|
|
||||||
|
|
||||||
# Collect experience
|
# Collect experience
|
||||||
if self._step > self.cfg.seed_steps:
|
if self._step > self.cfg.seed_steps:
|
||||||
@@ -106,8 +99,7 @@ class OnlineTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
action = self.env.rand_act()
|
action = self.env.rand_act()
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._tds.append(self.to_td(obs, action, reward, done))
|
self._tds.append(self.to_td(obs, action, reward))
|
||||||
self._ep_reward += reward
|
|
||||||
|
|
||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user