From 3ded0ebc83bbc5480ccba9aab5768688a0c38542 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 05:55:43 -0800 Subject: [PATCH 01/10] faster replay buffer implementation --- tdmpc2/common/buffer.py | 91 ++++---- tdmpc2/common/samplers.py | 365 +++++++++++++++++++++++++++++++ tdmpc2/config.yaml | 8 +- tdmpc2/envs/__init__.py | 8 +- tdmpc2/trainer/online_trainer.py | 24 +- 5 files changed, 428 insertions(+), 68 deletions(-) create mode 100644 tdmpc2/common/samplers.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index dbbfea6..10b74b2 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -2,30 +2,10 @@ from pathlib import Path import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.data.replay_buffers.samplers import RandomSampler from torchrl.envs import RandomCropTensorDict, Transform, Compose from common.logger import make_dir - - -class DataPrepTransform(Transform): - """ - Preprocesses data for TD-MPC2 training. - Replay data is expected to be a TensorDict with the following keys: - obs: observations - action: actions - reward: rewards - task: task IDs (optional) - A TensorDict with T time steps has T+1 observations and T actions and rewards. - The first actions and rewards in each TensorDict are dummies and should be ignored. - """ - - def __init__(self): - super().__init__([]) - - def forward(self, td): - td = td.permute(1,0) - return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None) +from common.samplers import SliceSampler class Buffer(): @@ -37,7 +17,9 @@ class Buffer(): def __init__(self, cfg): self.cfg = cfg self._device = torch.device('cuda') - self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length + self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) + self._capacity = min(cfg.buffer_size, cfg.steps) + self._num_steps = 0 self._num_eps = 0 @property @@ -45,6 +27,11 @@ class Buffer(): """Return the capacity of the buffer.""" return self._capacity + @property + def num_steps(self): + """Return the number of steps in the buffer.""" + return self._num_steps + @property def num_eps(self): """Return the number of episodes in the buffer.""" @@ -53,32 +40,25 @@ class Buffer(): def _reserve_buffer(self, storage): """ Reserve a buffer with the given storage. - Uses the RandomSampler to sample trajectories, - and the RandomCropTensorDict transform to crop trajectories to the desired length. - DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates. """ return ReplayBuffer( storage=storage, - sampler=RandomSampler(), - pin_memory=True, - prefetch=1, - transform=Compose( - RandomCropTensorDict(self.cfg.horizon+1, -1), - DataPrepTransform(), + sampler=SliceSampler( + slice_len=self.cfg.horizon+1, + end_key='done', + truncated_key=None, ), + pin_memory=True, + prefetch=2, batch_size=self.cfg.batch_size, ) - def _init(self, tds): + def _init(self, td): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" mem_free, _ = torch.cuda.mem_get_info() - bytes_per_ep = sum([ - (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ - else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) - print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._capacity + bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()]) + print(f'Bytes per step: {bytes_per_step:,}') + total_bytes = bytes_per_step*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory if 2.5*total_bytes > mem_free: # Insufficient CUDA memory @@ -92,22 +72,29 @@ class Buffer(): LazyTensorStorage(self._capacity, device=torch.device('cuda')) ) - def add(self, tds): - """Add an episode to the buffer. All episodes are expected to have the same length.""" - if self._num_eps == 0: - self._buffer = self._init(tds) - self._buffer.add(tds) - self._num_eps += 1 - return self._num_eps + def add(self, td): + """Add a step to the buffer.""" + done = bool(td['done'].any()) + if done: + self._num_eps +=1 + td['episode'] = torch.ones_like(td['done']) * self._num_eps + td['step'] = torch.arange(0, len(td)) + if self._num_steps == 0: + self._buffer = self._init(td) + self._buffer.extend(td) + self._num_steps += 1 + return self._num_steps def sample(self): """Sample a batch of sub-trajectories from the buffer.""" - obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size) - return obs.to(self._device, non_blocking=True), \ - action.to(self._device, non_blocking=True), \ - reward.to(self._device, non_blocking=True), \ - task.to(self._device, non_blocking=True) if task is not None else None - + td = self._buffer.sample(batch_size=self._batch_size) \ + .reshape(-1, self.cfg.horizon+1).permute(1, 0) + obs = td['obs'].to(self._device, non_blocking=True) + action = td['action'][1:].to(self._device, non_blocking=True) + reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True) + task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None + return obs, action, reward, task + def save(self): """Save the buffer to disk. Useful for storing offline datasets.""" td = self._buffer._storage._storage.cpu() diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py new file mode 100644 index 0000000..af0e073 --- /dev/null +++ b/tdmpc2/common/samplers.py @@ -0,0 +1,365 @@ +from __future__ import annotations + +import json +import warnings +from abc import ABC, abstractmethod +from copy import copy, deepcopy +from multiprocessing.context import get_spawning_popen +from pathlib import Path +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch + +from tensordict import MemoryMappedTensor +from tensordict.utils import NestedKey + +from torchrl._extension import EXTENSION_WARNING + +try: + from torchrl._torchrl import ( + MinSegmentTreeFp32, + MinSegmentTreeFp64, + SumSegmentTreeFp32, + SumSegmentTreeFp64, + ) +except ImportError: + warnings.warn(EXTENSION_WARNING) + +from torchrl.data.replay_buffers.storages import Storage, TensorStorage +from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES +from torchrl.data.replay_buffers.samplers import Sampler + +_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage." + + +class SliceSampler(Sampler): + """Samples slices of data along the first dimension, given start and stop signals. + + This class samples sub-trajectories with replacement. For a version without + replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`. + + Keyword Args: + num_slices (int): the number of slices to be sampled. The batch-size + must be greater or equal to the ``num_slices`` argument. Exclusive + with ``slice_len``. + slice_len (int): the length of the slices to be sampled. The batch-size + must be greater or equal to the ``slice_len`` argument and divisible + by it. Exclusive with ``num_slices``. + end_key (NestedKey, optional): the key indicating the end of a + trajectory (or episode). Defaults to ``("next", "done")``. + traj_key (NestedKey, optional): the key indicating the trajectories. + Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + cache_values (bool, optional): to be used with static datasets. + Will cache the start and end signal of the trajectory. + truncated_key (NestedKey, optional): If not ``None``, this argument + indicates where a truncated signal should be written in the output + data. This is used to indicate to value estimators where the provided + trajectory breaks. Defaults to ``("next", "truncated")``. + This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` + instances (otherwise the truncated key is returned in the info dictionary + returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method). + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + + .. note:: To recover the trajectory splits in the storage, + :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first + attempt to find the ``traj_key`` entry in the storage. If it cannot be + found, the ``end_key`` will be used to reconstruct the episodes. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer + >>> from torchrl.data.replay_buffers.samplers import SliceSampler + >>> torch.manual_seed(0) + >>> rb = TensorDictReplayBuffer( + ... storage=LazyMemmapStorage(1_000_000), + ... sampler=SliceSampler(cache_values=True, num_slices=10), + ... batch_size=320, + ... ) + >>> episode = torch.zeros(1000, dtype=torch.int) + >>> episode[:300] = 1 + >>> episode[300:550] = 2 + >>> episode[550:700] = 3 + >>> episode[700:] = 4 + >>> data = TensorDict( + ... { + ... "episode": episode, + ... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5), + ... "act": torch.randn((20,)).expand(1000, 20), + ... "other": torch.randn((20, 50)).expand(1000, 20, 50), + ... }, [1000] + ... ) + >>> rb.extend(data) + >>> sample = rb.sample() + >>> print("sample:", sample) + >>> print("episodes", sample.get("episode").unique()) + episodes tensor([1, 2, 3, 4], dtype=torch.int32) + + :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with + most of TorchRL's datasets: + + Examples: + >>> import torch + >>> + >>> from torchrl.data.datasets import RobosetExperienceReplay + >>> from torchrl.data import SliceSampler + >>> + >>> torch.manual_seed(0) + >>> num_slices = 10 + >>> dataid = list(RobosetExperienceReplay.available_datasets)[0] + >>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices)) + >>> for batch in data: + ... batch = batch.reshape(num_slices, -1) + ... break + >>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1)) + check that each batch only has one episode: tensor([[19], + [14], + [ 8], + [10], + [13], + [ 4], + [ 2], + [ 3], + [22], + [ 8]]) + + """ + + def __init__( + self, + *, + num_slices: int = None, + slice_len: int = None, + end_key: NestedKey | None = None, + traj_key: NestedKey | None = None, + cache_values: bool = False, + truncated_key: NestedKey | None = ("next", "truncated"), + strict_length: bool = True, + ) -> object: + if end_key is None: + end_key = ("next", "done") + if traj_key is None: + traj_key = "episode" + if not ((num_slices is None) ^ (slice_len is None)): + raise TypeError( + "Either num_slices or slice_len must be not None, and not both. " + f"Got num_slices={num_slices} and slice_len={slice_len}." + ) + self.num_slices = num_slices + self.slice_len = slice_len + self.end_key = end_key + self.traj_key = traj_key + self.truncated_key = truncated_key + self.cache_values = cache_values + self._fetch_traj = True + self._uses_data_prefix = False + self.strict_length = strict_length + self._cache = {} + + @staticmethod + def _find_start_stop_traj(*, trajectory=None, end=None): + if trajectory is not None: + # slower + # _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True) + # stop_idx = stop_idx.cumsum(0) - 1 + + # even slower + # t = trajectory.unsqueeze(0) + # w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2) + # stop_idx = torch.conv1d(t, w).nonzero() + + # faster + end = trajectory[:-1] != trajectory[1:] + end = torch.cat([end, torch.ones_like(end[:1])], 0) + else: + end = torch.index_fill( + end, + index=torch.tensor(-1, device=end.device, dtype=torch.long), + dim=0, + value=1, + ) + if end.ndim != 1: + raise RuntimeError( + f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead." + ) + stop_idx = end.view(-1).nonzero().view(-1) + start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1]) + lengths = stop_idx - start_idx + 1 + return start_idx, stop_idx, lengths + + def _tensor_slices_from_startend(self, seq_length, start): + if isinstance(seq_length, int): + return ( + torch.arange( + seq_length, device=start.device, dtype=start.dtype + ).unsqueeze(0) + + start.unsqueeze(1) + ).view(-1) + else: + # when padding is needed + return torch.cat( + [ + _start + + torch.arange(_seq_len, device=start.device, dtype=start.dtype) + for _start, _seq_len in zip(start, seq_length) + ] + ) + + def _get_stop_and_length(self, storage, fallback=True): + if self.cache_values and "stop-and-length" in self._cache: + return self._cache.get("stop-and-length") + + if self._fetch_traj: + # We first try with the traj_key + try: + # In some cases, the storage hides the data behind "_data". + # In the future, this may be deprecated, and we don't want to mess + # with the keys provided by the user so we fall back on a proxy to + # the traj key. + try: + trajectory = storage._storage.get(self._used_traj_key) + except KeyError: + trajectory = storage._storage.get(("_data", self.traj_key)) + # cache that value for future use + self._used_traj_key = ("_data", self.traj_key) + self._uses_data_prefix = ( + isinstance(self._used_traj_key, tuple) + and self._used_traj_key[0] == "_data" + ) + vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) + return self._cache.setdefault("stop-and-length", vals) + except KeyError: + if fallback: + self._fetch_traj = False + return self._get_stop_and_length(storage, fallback=False) + raise + + else: + try: + # In some cases, the storage hides the data behind "_data". + # In the future, this may be deprecated, and we don't want to mess + # with the keys provided by the user so we fall back on a proxy to + # the traj key. + try: + done = storage._storage.get(self._used_end_key) + except KeyError: + done = storage._storage.get(("_data", self.end_key)) + # cache that value for future use + self._used_end_key = ("_data", self.end_key) + self._uses_data_prefix = ( + isinstance(self._used_end_key, tuple) + and self._used_end_key[0] == "_data" + ) + vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] + return self._cache.setdefault("stop-and-length", vals) + except KeyError: + if fallback: + self._fetch_traj = True + return self._get_stop_and_length(storage, fallback=False) + raise + + def _adjusted_batch_size(self, batch_size): + if self.num_slices is not None: + if batch_size % self.num_slices != 0: + raise RuntimeError( + f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}." + ) + seq_length = batch_size // self.num_slices + num_slices = self.num_slices + else: + if batch_size % self.slice_len != 0: + raise RuntimeError( + f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}." + ) + seq_length = self.slice_len + num_slices = batch_size // self.slice_len + return seq_length, num_slices + + def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + if not isinstance(storage, TensorStorage): + raise RuntimeError( + f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead." + ) + + # pick up as many trajs as we need + start_idx, stop_idx, lengths = self._get_stop_and_length(storage) + seq_length, num_slices = self._adjusted_batch_size(batch_size) + return self._sample_slices(lengths, start_idx, seq_length, num_slices) + + def _sample_slices( + self, lengths, start_idx, seq_length, num_slices, traj_idx=None + ) -> Tuple[torch.Tensor, dict]: + if (lengths < seq_length).any(): + if self.strict_length: + raise RuntimeError( + "Some stored trajectories have a length shorter than the slice that was asked for. " + "Create the sampler with `strict_length=False` to allow shorter trajectories to appear " + "in you batch." + ) + # make seq_length a tensor with values clamped by lengths + seq_length = lengths.clamp_max(seq_length) + + if traj_idx is None: + traj_idx = torch.randint( + lengths.shape[0], (num_slices,), device=lengths.device + ) + else: + num_slices = traj_idx.shape[0] + relative_starts = ( + ( + torch.rand(num_slices, device=lengths.device) + * (lengths[traj_idx] - seq_length) + ) + .floor() + .to(start_idx.dtype) + ) + starts = start_idx[traj_idx] + relative_starts + index = self._tensor_slices_from_startend(seq_length, starts) + if self.truncated_key is not None: + truncated_key = self.truncated_key + + truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device) + if isinstance(seq_length, int): + truncated.view(num_slices, -1)[:, -1] = 1 + else: + truncated[seq_length.cumsum(0) - 1] = 1 + return index.to(torch.long), {truncated_key: truncated} + return index.to(torch.long), {} + + @property + def _used_traj_key(self): + return self.__dict__.get("__used_traj_key", self.traj_key) + + @_used_traj_key.setter + def _used_traj_key(self, value): + self.__dict__["__used_traj_key"] = value + + @property + def _used_end_key(self): + return self.__dict__.get("__used_end_key", self.end_key) + + @_used_end_key.setter + def _used_end_key(self, value): + self.__dict__["__used_end_key"] = value + + def _empty(self): + pass + + def dumps(self, path): + # no op - cache does not need to be saved + ... + + def loads(self, path): + # no op + ... + + def __getstate__(self): + state = copy(self.__dict__) + state["_cache"] = {} + return state diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b625bf5..083bdcf 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -60,11 +60,11 @@ dropout: 0.01 simnorm_dim: 8 # logging -wandb_project: ??? -wandb_entity: ??? +wandb_project: tdmpcv2 +wandb_entity: nicklashansen wandb_silent: false -disable_wandb: true -save_csv: true +disable_wandb: false +save_csv: false # misc save_video: true diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index ef2a630..0d78d27 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -6,9 +6,9 @@ import gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.tensor import TensorWrapper from envs.dmcontrol import make_env as make_dm_control_env -from envs.maniskill import make_env as make_maniskill_env -from envs.metaworld import make_env as make_metaworld_env -from envs.myosuite import make_env as make_myosuite_env +# from envs.maniskill import make_env as make_maniskill_env +# from envs.metaworld import make_env as make_metaworld_env +# from envs.myosuite import make_env as make_myosuite_env from envs.exceptions import UnknownTaskError warnings.filterwarnings('ignore', category=DeprecationWarning) @@ -44,7 +44,7 @@ def make_env(cfg): env = make_multitask_env(cfg) else: env = None - for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: + for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) except UnknownTaskError: diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 94835ca..52d92d8 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -14,6 +14,7 @@ class OnlineTrainer(Trainer): super().__init__(*args, **kwargs) self._step = 0 self._ep_idx = 0 + self._ep_reward = 0 self._start_time = time() def common_metrics(self): @@ -21,6 +22,7 @@ class OnlineTrainer(Trainer): return dict( step=self._step, episode=self._ep_idx, + episode_reward=self._ep_reward, total_time=time() - self._start_time, ) @@ -47,22 +49,26 @@ class OnlineTrainer(Trainer): episode_success=np.nanmean(ep_successes), ) - def to_td(self, obs, action=None, reward=None): + def to_td(self, obs, action=None, reward=None, done=None): """Creates a TensorDict for a new episode.""" if isinstance(obs, dict): - obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() + obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu() else: - obs = obs.unsqueeze(0).cpu() + obs = obs.cpu() if action is None: action = torch.empty_like(self.env.rand_act()) if reward is None: reward = torch.tensor(float('nan')) + if done is None: + done = False + done = torch.tensor(done) td = TensorDict(dict( - obs=obs, + obs=obs.unsqueeze(0), action=action.unsqueeze(0), reward=reward.unsqueeze(0), + done=done.unsqueeze(0), ), batch_size=(1,)) - return td + return td def train(self): """Train a TD-MPC2 agent.""" @@ -83,15 +89,16 @@ class OnlineTrainer(Trainer): if self._step > 0: train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_success=info['success'], ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + self._ep_idx += 1 + self.buffer.add(torch.cat(self._tds)) obs = self.env.reset() self._tds = [self.to_td(obs)] + self._ep_reward = 0 # Collect experience if self._step > self.cfg.seed_steps: @@ -99,7 +106,8 @@ class OnlineTrainer(Trainer): else: action = self.env.rand_act() obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + self._tds.append(self.to_td(obs, action, reward, done)) + self._ep_reward += reward # Update agent if self._step >= self.cfg.seed_steps: From fea0936e6972830ac07fe9b420d32045b88a420c Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 07:44:34 -0800 Subject: [PATCH 02/10] set logging defaults --- tdmpc2/config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index ae98d43..b720923 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -62,11 +62,11 @@ dropout: 0.01 simnorm_dim: 8 # logging -wandb_project: tdmpcv2 -wandb_entity: nicklashansen +wandb_project: ??? +wandb_entity: ??? wandb_silent: false -disable_wandb: false -save_csv: false +disable_wandb: true +save_csv: true # misc save_video: true From 34ea3662cd7b7ecce92415b68693a745b899925c Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 13:34:12 -0800 Subject: [PATCH 03/10] compare new/old buffers --- tdmpc2/common/buffer.py | 43 +++++------- tdmpc2/common/legacy_buffer.py | 115 +++++++++++++++++++++++++++++++ tdmpc2/train.py | 2 + tdmpc2/trainer/online_trainer.py | 24 +++---- 4 files changed, 144 insertions(+), 40 deletions(-) create mode 100644 tdmpc2/common/legacy_buffer.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 10b74b2..aae747a 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -2,7 +2,6 @@ from pathlib import Path import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.envs import RandomCropTensorDict, Transform, Compose from common.logger import make_dir from common.samplers import SliceSampler @@ -19,7 +18,6 @@ class Buffer(): self._device = torch.device('cuda') self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) self._capacity = min(cfg.buffer_size, cfg.steps) - self._num_steps = 0 self._num_eps = 0 @property @@ -27,11 +25,6 @@ class Buffer(): """Return the capacity of the buffer.""" return self._capacity - @property - def num_steps(self): - """Return the number of steps in the buffer.""" - return self._num_steps - @property def num_eps(self): """Return the number of episodes in the buffer.""" @@ -45,7 +38,8 @@ class Buffer(): storage=storage, sampler=SliceSampler( slice_len=self.cfg.horizon+1, - end_key='done', + end_key=None, + traj_key='episode', truncated_key=None, ), pin_memory=True, @@ -53,12 +47,16 @@ class Buffer(): batch_size=self.cfg.batch_size, ) - def _init(self, td): + def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" mem_free, _ = torch.cuda.mem_get_info() - bytes_per_step = sum([x.numel()*x.element_size() for x in td[0].values()]) - print(f'Bytes per step: {bytes_per_step:,}') - total_bytes = bytes_per_step*self._capacity + bytes_per_ep = sum([ + (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ + else sum([x.numel()*x.element_size() for x in v.values()])) \ + for k,v in tds.items() + ]) + print(f'Bytes per episode: {bytes_per_ep:,}') + total_bytes = bytes_per_ep*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory if 2.5*total_bytes > mem_free: # Insufficient CUDA memory @@ -72,23 +70,20 @@ class Buffer(): LazyTensorStorage(self._capacity, device=torch.device('cuda')) ) - def add(self, td): + def add(self, tds): """Add a step to the buffer.""" - done = bool(td['done'].any()) - if done: - self._num_eps +=1 - td['episode'] = torch.ones_like(td['done']) * self._num_eps - td['step'] = torch.arange(0, len(td)) - if self._num_steps == 0: - self._buffer = self._init(td) - self._buffer.extend(td) - self._num_steps += 1 - return self._num_steps + tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps + tds['step'] = torch.arange(0, len(tds)) + if self._num_eps == 0: + self._buffer = self._init(tds) + self._buffer.extend(tds) + self._num_eps += 1 + return self._num_eps def sample(self): """Sample a batch of sub-trajectories from the buffer.""" td = self._buffer.sample(batch_size=self._batch_size) \ - .reshape(-1, self.cfg.horizon+1).permute(1, 0) + .view(-1, self.cfg.horizon+1).permute(1, 0) obs = td['obs'].to(self._device, non_blocking=True) action = td['action'][1:].to(self._device, non_blocking=True) reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True) diff --git a/tdmpc2/common/legacy_buffer.py b/tdmpc2/common/legacy_buffer.py new file mode 100644 index 0000000..dbbfea6 --- /dev/null +++ b/tdmpc2/common/legacy_buffer.py @@ -0,0 +1,115 @@ +from pathlib import Path +import torch +from tensordict.tensordict import TensorDict +from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage +from torchrl.data.replay_buffers.samplers import RandomSampler +from torchrl.envs import RandomCropTensorDict, Transform, Compose + +from common.logger import make_dir + + +class DataPrepTransform(Transform): + """ + Preprocesses data for TD-MPC2 training. + Replay data is expected to be a TensorDict with the following keys: + obs: observations + action: actions + reward: rewards + task: task IDs (optional) + A TensorDict with T time steps has T+1 observations and T actions and rewards. + The first actions and rewards in each TensorDict are dummies and should be ignored. + """ + + def __init__(self): + super().__init__([]) + + def forward(self, td): + td = td.permute(1,0) + return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None) + + +class Buffer(): + """ + Create a replay buffer for TD-MPC2 training. + Uses CUDA memory if available, and CPU memory otherwise. + """ + + def __init__(self, cfg): + self.cfg = cfg + self._device = torch.device('cuda') + self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length + self._num_eps = 0 + + @property + def capacity(self): + """Return the capacity of the buffer.""" + return self._capacity + + @property + def num_eps(self): + """Return the number of episodes in the buffer.""" + return self._num_eps + + def _reserve_buffer(self, storage): + """ + Reserve a buffer with the given storage. + Uses the RandomSampler to sample trajectories, + and the RandomCropTensorDict transform to crop trajectories to the desired length. + DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates. + """ + return ReplayBuffer( + storage=storage, + sampler=RandomSampler(), + pin_memory=True, + prefetch=1, + transform=Compose( + RandomCropTensorDict(self.cfg.horizon+1, -1), + DataPrepTransform(), + ), + batch_size=self.cfg.batch_size, + ) + + def _init(self, tds): + """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" + mem_free, _ = torch.cuda.mem_get_info() + bytes_per_ep = sum([ + (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ + else sum([x.numel()*x.element_size() for x in v.values()])) \ + for k,v in tds.items() + ]) + print(f'Bytes per episode: {bytes_per_ep:,}') + total_bytes = bytes_per_ep*self._capacity + print(f'Storage required: {total_bytes/1e9:.2f} GB') + # Heuristic: decide whether to use CUDA or CPU memory + if 2.5*total_bytes > mem_free: # Insufficient CUDA memory + print('Using CPU memory for storage.') + return self._reserve_buffer( + LazyTensorStorage(self._capacity, device=torch.device('cpu')) + ) + else: # Sufficient CUDA memory + print('Using CUDA memory for storage.') + return self._reserve_buffer( + LazyTensorStorage(self._capacity, device=torch.device('cuda')) + ) + + def add(self, tds): + """Add an episode to the buffer. All episodes are expected to have the same length.""" + if self._num_eps == 0: + self._buffer = self._init(tds) + self._buffer.add(tds) + self._num_eps += 1 + return self._num_eps + + def sample(self): + """Sample a batch of sub-trajectories from the buffer.""" + obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size) + return obs.to(self._device, non_blocking=True), \ + action.to(self._device, non_blocking=True), \ + reward.to(self._device, non_blocking=True), \ + task.to(self._device, non_blocking=True) if task is not None else None + + def save(self): + """Save the buffer to disk. Useful for storing offline datasets.""" + td = self._buffer._storage._storage.cpu() + fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' + torch.save(td, fp) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index a35c11b..cf9dbb8 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -1,5 +1,6 @@ import os os.environ['MUJOCO_GL'] = 'egl' +os.environ['LAZY_LEGACY_OP'] = 0 import warnings warnings.filterwarnings('ignore') import torch @@ -10,6 +11,7 @@ from termcolor import colored from common.parser import parse_cfg from common.seed import set_seed from common.buffer import Buffer +# from common.legacy_buffer import Buffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 52d92d8..94835ca 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -14,7 +14,6 @@ class OnlineTrainer(Trainer): super().__init__(*args, **kwargs) self._step = 0 self._ep_idx = 0 - self._ep_reward = 0 self._start_time = time() def common_metrics(self): @@ -22,7 +21,6 @@ class OnlineTrainer(Trainer): return dict( step=self._step, episode=self._ep_idx, - episode_reward=self._ep_reward, total_time=time() - self._start_time, ) @@ -49,26 +47,22 @@ class OnlineTrainer(Trainer): episode_success=np.nanmean(ep_successes), ) - def to_td(self, obs, action=None, reward=None, done=None): + def to_td(self, obs, action=None, reward=None): """Creates a TensorDict for a new episode.""" if isinstance(obs, dict): - obs = TensorDict({k: v for k,v in obs.items()}, batch_size=()).cpu() + obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() else: - obs = obs.cpu() + obs = obs.unsqueeze(0).cpu() if action is None: action = torch.empty_like(self.env.rand_act()) if reward is None: reward = torch.tensor(float('nan')) - if done is None: - done = False - done = torch.tensor(done) td = TensorDict(dict( - obs=obs.unsqueeze(0), + obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), - done=done.unsqueeze(0), ), batch_size=(1,)) - return td + return td def train(self): """Train a TD-MPC2 agent.""" @@ -89,16 +83,15 @@ class OnlineTrainer(Trainer): if self._step > 0: train_metrics.update( + episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_success=info['success'], ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') - self._ep_idx += 1 - self.buffer.add(torch.cat(self._tds)) + self._ep_idx = self.buffer.add(torch.cat(self._tds)) obs = self.env.reset() self._tds = [self.to_td(obs)] - self._ep_reward = 0 # Collect experience if self._step > self.cfg.seed_steps: @@ -106,8 +99,7 @@ class OnlineTrainer(Trainer): else: action = self.env.rand_act() obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward, done)) - self._ep_reward += reward + self._tds.append(self.to_td(obs, action, reward)) # Update agent if self._step >= self.cfg.seed_steps: From 2929cfdb44410b3af58642026145923ed0a319a6 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 13:57:01 -0800 Subject: [PATCH 04/10] fix computation of mem requirements --- tdmpc2/common/buffer.py | 6 +++--- tdmpc2/train.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index aae747a..7010549 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -53,10 +53,10 @@ class Buffer(): bytes_per_ep = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) + for v in tds.values() + ]) print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._capacity + total_bytes = bytes_per_ep * (self._capacity // self.cfg.episode_length) print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory if 2.5*total_bytes > mem_free: # Insufficient CUDA memory diff --git a/tdmpc2/train.py b/tdmpc2/train.py index cf9dbb8..ded21e3 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -1,6 +1,6 @@ import os os.environ['MUJOCO_GL'] = 'egl' -os.environ['LAZY_LEGACY_OP'] = 0 +os.environ['LAZY_LEGACY_OP'] = '0' import warnings warnings.filterwarnings('ignore') import torch From 70fe242adc23cc19d6f6e466bcc423f7f6523c21 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 22 Dec 2023 14:26:48 -0800 Subject: [PATCH 05/10] does not reproduce results w/ previous buffer --- tdmpc2/common/buffer.py | 4 ++-- tdmpc2/trainer/online_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 7010549..8d914a7 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -37,13 +37,13 @@ class Buffer(): return ReplayBuffer( storage=storage, sampler=SliceSampler( - slice_len=self.cfg.horizon+1, + num_slices=self.cfg.batch_size, end_key=None, traj_key='episode', truncated_key=None, ), pin_memory=True, - prefetch=2, + prefetch=1, batch_size=self.cfg.batch_size, ) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 94835ca..f5f65cc 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -50,7 +50,7 @@ class OnlineTrainer(Trainer): def to_td(self, obs, action=None, reward=None): """Creates a TensorDict for a new episode.""" if isinstance(obs, dict): - obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu() + obs = TensorDict(obs, batch_size=(), device='cpu') else: obs = obs.unsqueeze(0).cpu() if action is None: From eef1d1b407cfa38f5168b4f11ece4019a741e444 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sat, 23 Dec 2023 07:43:28 -0800 Subject: [PATCH 06/10] unit test buffer implementations --- tdmpc2/common/buffer.py | 142 ++++++++++++++++++++++----------- tdmpc2/common/legacy_buffer.py | 115 -------------------------- tdmpc2/test_buffer.py | 64 +++++++++++++++ tdmpc2/train.py | 6 +- 4 files changed, 164 insertions(+), 163 deletions(-) delete mode 100644 tdmpc2/common/legacy_buffer.py create mode 100644 tdmpc2/test_buffer.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 8d914a7..1b9330c 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -1,24 +1,27 @@ -from pathlib import Path import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage +from torchrl.data.replay_buffers.samplers import RandomSampler +from torchrl.envs import RandomCropTensorDict -from common.logger import make_dir from common.samplers import SliceSampler class Buffer(): """ - Create a replay buffer for TD-MPC2 training. + Base class for TD-MPC2 replay buffers. Uses CUDA memory if available, and CPU memory otherwise. """ def __init__(self, cfg): self.cfg = cfg self._device = torch.device('cuda') - self._batch_size = self.cfg.batch_size * (self.cfg.horizon+1) - self._capacity = min(cfg.buffer_size, cfg.steps) + self._capacity = None + self._max_eps = None self._num_eps = 0 + self._sampler = None + self._transform = None + self._batch_size = None @property def capacity(self): @@ -36,15 +39,11 @@ class Buffer(): """ return ReplayBuffer( storage=storage, - sampler=SliceSampler( - num_slices=self.cfg.batch_size, - end_key=None, - traj_key='episode', - truncated_key=None, - ), + sampler=self._sampler, pin_memory=True, prefetch=1, - batch_size=self.cfg.batch_size, + transform=self._transform, + batch_size=self._batch_size, ) def _init(self, tds): @@ -53,45 +52,98 @@ class Buffer(): bytes_per_ep = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ else sum([x.numel()*x.element_size() for x in v.values()])) \ - for v in tds.values() - ]) + for k,v in tds.items() + ]) print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep * (self._capacity // self.cfg.episode_length) + total_bytes = bytes_per_ep*self._max_eps print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory - if 2.5*total_bytes > mem_free: # Insufficient CUDA memory - print('Using CPU memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cpu')) - ) - else: # Sufficient CUDA memory - print('Using CUDA memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cuda')) - ) + storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' + print(f'Using {storage_device.upper()} memory for storage.') + return self._reserve_buffer( + LazyTensorStorage(self._capacity, device=torch.device(storage_device)) + ) - def add(self, tds): - """Add a step to the buffer.""" - tds['episode'] = torch.ones_like(tds['reward'], dtype=torch.int64) * self._num_eps - tds['step'] = torch.arange(0, len(tds)) + def _to_device(self, *args, device=None): + if device is None: + device = self._device + return (arg.to(device, non_blocking=True) \ + if arg is not None else None for arg in args) + + def _prepare_batch(self, td): + """ + Prepare a sampled batch for training (post-processing). + Expects `td` to be a TensorDict with batch size TxB. + """ + obs = td['obs'] + action = td['action'][1:] + reward = td['reward'][1:].unsqueeze(-1) + task = td['task'][0] if 'task' in td.keys() else None + return self._to_device(obs, action, reward, task) + + def add(self, td): + """Add an episode to the buffer.""" + pass + + def sample(self): + """Sample a batch of sub-trajectories from the buffer.""" + pass + + +class CropBuffer(Buffer): + """ + A replay buffer that first samples trajectories, and then crops to desired length. + """ + + def __init__(self, cfg): + super().__init__(cfg) + self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length + self._max_eps = self._capacity + self._sampler = RandomSampler() + self._transform = RandomCropTensorDict(cfg.horizon+1, -1) + self._batch_size = cfg.batch_size + + def add(self, td): + """Add an episode to the buffer. All episodes are expected to be equal length.""" if self._num_eps == 0: - self._buffer = self._init(tds) - self._buffer.extend(tds) + self._buffer = self._init(td) + self._buffer.add(td) self._num_eps += 1 return self._num_eps def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - td = self._buffer.sample(batch_size=self._batch_size) \ - .view(-1, self.cfg.horizon+1).permute(1, 0) - obs = td['obs'].to(self._device, non_blocking=True) - action = td['action'][1:].to(self._device, non_blocking=True) - reward = td['reward'][1:].unsqueeze(-1).to(self._device, non_blocking=True) - task = td['task'][0].to(self._device, non_blocking=True) if 'task' in td.keys() else None - return obs, action, reward, task - - def save(self): - """Save the buffer to disk. Useful for storing offline datasets.""" - td = self._buffer._storage._storage.cpu() - fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' - torch.save(td, fp) + """Sample a batch of subsequences from the buffer.""" + td = self._buffer.sample().permute(1,0) + return self._prepare_batch(td) + + +class SliceBuffer(Buffer): + """ + A replay buffer that directly samples subsequences. More efficient than CropBuffer. + """ + + def __init__(self, cfg): + super().__init__(cfg) + self._capacity = min(cfg.buffer_size, cfg.steps) + self._max_eps = self._capacity//cfg.episode_length + self._sampler = SliceSampler( + num_slices=self.cfg.batch_size, + end_key=None, + traj_key='episode', + truncated_key=None, + ) + self._batch_size = cfg.batch_size * (cfg.horizon+1) + + def add(self, td): + """Add an episode to the buffer. Supports variable episode lengths.""" + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + if self._num_eps == 0: + self._buffer = self._init(td) + self._buffer.extend(td) + self._num_eps += 1 + return self._num_eps + + def sample(self): + """Sample a batch of subsequences from the buffer.""" + td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0) + return self._prepare_batch(td) diff --git a/tdmpc2/common/legacy_buffer.py b/tdmpc2/common/legacy_buffer.py deleted file mode 100644 index dbbfea6..0000000 --- a/tdmpc2/common/legacy_buffer.py +++ /dev/null @@ -1,115 +0,0 @@ -from pathlib import Path -import torch -from tensordict.tensordict import TensorDict -from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.envs import RandomCropTensorDict, Transform, Compose - -from common.logger import make_dir - - -class DataPrepTransform(Transform): - """ - Preprocesses data for TD-MPC2 training. - Replay data is expected to be a TensorDict with the following keys: - obs: observations - action: actions - reward: rewards - task: task IDs (optional) - A TensorDict with T time steps has T+1 observations and T actions and rewards. - The first actions and rewards in each TensorDict are dummies and should be ignored. - """ - - def __init__(self): - super().__init__([]) - - def forward(self, td): - td = td.permute(1,0) - return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None) - - -class Buffer(): - """ - Create a replay buffer for TD-MPC2 training. - Uses CUDA memory if available, and CPU memory otherwise. - """ - - def __init__(self, cfg): - self.cfg = cfg - self._device = torch.device('cuda') - self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length - self._num_eps = 0 - - @property - def capacity(self): - """Return the capacity of the buffer.""" - return self._capacity - - @property - def num_eps(self): - """Return the number of episodes in the buffer.""" - return self._num_eps - - def _reserve_buffer(self, storage): - """ - Reserve a buffer with the given storage. - Uses the RandomSampler to sample trajectories, - and the RandomCropTensorDict transform to crop trajectories to the desired length. - DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates. - """ - return ReplayBuffer( - storage=storage, - sampler=RandomSampler(), - pin_memory=True, - prefetch=1, - transform=Compose( - RandomCropTensorDict(self.cfg.horizon+1, -1), - DataPrepTransform(), - ), - batch_size=self.cfg.batch_size, - ) - - def _init(self, tds): - """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" - mem_free, _ = torch.cuda.mem_get_info() - bytes_per_ep = sum([ - (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ - else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) - print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._capacity - print(f'Storage required: {total_bytes/1e9:.2f} GB') - # Heuristic: decide whether to use CUDA or CPU memory - if 2.5*total_bytes > mem_free: # Insufficient CUDA memory - print('Using CPU memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cpu')) - ) - else: # Sufficient CUDA memory - print('Using CUDA memory for storage.') - return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device('cuda')) - ) - - def add(self, tds): - """Add an episode to the buffer. All episodes are expected to have the same length.""" - if self._num_eps == 0: - self._buffer = self._init(tds) - self._buffer.add(tds) - self._num_eps += 1 - return self._num_eps - - def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size) - return obs.to(self._device, non_blocking=True), \ - action.to(self._device, non_blocking=True), \ - reward.to(self._device, non_blocking=True), \ - task.to(self._device, non_blocking=True) if task is not None else None - - def save(self): - """Save the buffer to disk. Useful for storing offline datasets.""" - td = self._buffer._storage._storage.cpu() - fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt' - torch.save(td, fp) diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py new file mode 100644 index 0000000..7a58d08 --- /dev/null +++ b/tdmpc2/test_buffer.py @@ -0,0 +1,64 @@ +import os +os.environ['LAZY_LEGACY_OP'] = '0' + +import torch +import hydra +from tensordict.tensordict import TensorDict + +from common.buffer import CropBuffer, SliceBuffer + + +@hydra.main(config_name='config', config_path='.') +def test_buffer(cfg: dict): + cfg.episode_length = 12 + cfg.batch_size = 8 + + transitions0 = [TensorDict(dict( + obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, + action=torch.tensor([-1.]) ** t, + reward=torch.tensor([1.]) * t, + ), batch_size=(1,)) for t in range(cfg.episode_length)] + episode0 = torch.cat(transitions0) + + transitions1 = [TensorDict(dict( + obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, + action=(torch.tensor([-1.]) ** t) * 0.5, + reward=torch.tensor([-1.]) * t, + ), batch_size=(1,)) for t in range(cfg.episode_length)] + episode1 = torch.cat(transitions1) + + crop_buffer = CropBuffer(cfg) + slice_buffer = SliceBuffer(cfg) + + crop_buffer.add(episode0) + slice_buffer.add(episode0) + + crop_buffer.add(episode1) + slice_buffer.add(episode1) + + crop_obs, crop_action, crop_reward, _ = crop_buffer.sample() + slice_obs, slice_action, slice_reward, _ = slice_buffer.sample() + + assert crop_obs.shape == slice_obs.shape + assert crop_action.shape == slice_action.shape + assert crop_reward.shape == slice_reward.shape + + assert (crop_obs[1:] - crop_obs[:-1] == 1.).all() + assert (slice_obs[1:] - slice_obs[:-1] == 1.).all() + assert (crop_action.mean().abs() < 0.2) + assert (slice_action.mean().abs() < 0.2) + + crop_rewards, slice_rewards = [], [] + for _ in range(100_000): + _, _, crop_reward_, _ = crop_buffer.sample() + _, _, slice_reward_, _ = slice_buffer.sample() + crop_rewards.append(crop_reward_.mean()) + slice_rewards.append(slice_reward_.mean()) + + crop_rewards = torch.tensor(crop_rewards).mean() + slice_rewards = torch.tensor(slice_rewards).mean() + assert (crop_rewards - slice_rewards) < 0.1 + + +if __name__ == '__main__': + test_buffer() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index ded21e3..5303e09 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -10,8 +10,7 @@ from termcolor import colored from common.parser import parse_cfg from common.seed import set_seed -from common.buffer import Buffer -# from common.legacy_buffer import Buffer +from common.buffer import CropBuffer, SliceBuffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer @@ -52,7 +51,8 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=Buffer(cfg), + buffer=CropBuffer(cfg), + # buffer=SliceBuffer(cfg), logger=Logger(cfg), ) trainer.train() From ca4dfa1db35593cdb8636d7e0aaa21997384905d Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sat, 23 Dec 2023 09:39:13 -0800 Subject: [PATCH 07/10] further reduce buffer differences --- tdmpc2/common/buffer.py | 29 +++++++++++++++-------------- tdmpc2/test_buffer.py | 6 +++--- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 1b9330c..75d17d2 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -48,6 +48,7 @@ class Buffer(): def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" + print('Buffer capacity:', self._capacity) mem_free, _ = torch.cuda.mem_get_info() bytes_per_ep = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ @@ -81,9 +82,18 @@ class Buffer(): task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, task) + def _add(self, td): + """Internal function that adds episode to the buffer.""" + pass + def add(self, td): """Add an episode to the buffer.""" - pass + td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + if self._num_eps == 0: + self._buffer = self._init(td) + self._add(td) + self._num_eps += 1 + return self._num_eps def sample(self): """Sample a batch of sub-trajectories from the buffer.""" @@ -103,13 +113,9 @@ class CropBuffer(Buffer): self._transform = RandomCropTensorDict(cfg.horizon+1, -1) self._batch_size = cfg.batch_size - def add(self, td): - """Add an episode to the buffer. All episodes are expected to be equal length.""" - if self._num_eps == 0: - self._buffer = self._init(td) + def _add(self, td): + """Add an episode to the buffer, with trajectories as the leading dimension.""" self._buffer.add(td) - self._num_eps += 1 - return self._num_eps def sample(self): """Sample a batch of subsequences from the buffer.""" @@ -134,14 +140,9 @@ class SliceBuffer(Buffer): ) self._batch_size = cfg.batch_size * (cfg.horizon+1) - def add(self, td): - """Add an episode to the buffer. Supports variable episode lengths.""" - td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps - if self._num_eps == 0: - self._buffer = self._init(td) + def _add(self, td): + """Add an episode to the buffer, with transitions as the leading dimension.""" self._buffer.extend(td) - self._num_eps += 1 - return self._num_eps def sample(self): """Sample a batch of subsequences from the buffer.""" diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py index 7a58d08..87e9212 100644 --- a/tdmpc2/test_buffer.py +++ b/tdmpc2/test_buffer.py @@ -10,19 +10,19 @@ from common.buffer import CropBuffer, SliceBuffer @hydra.main(config_name='config', config_path='.') def test_buffer(cfg: dict): - cfg.episode_length = 12 + cfg.episode_length = 11 cfg.batch_size = 8 transitions0 = [TensorDict(dict( obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, - action=torch.tensor([-1.]) ** t, + action=torch.tensor([-1.]).unsqueeze(0) ** t, reward=torch.tensor([1.]) * t, ), batch_size=(1,)) for t in range(cfg.episode_length)] episode0 = torch.cat(transitions0) transitions1 = [TensorDict(dict( obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, - action=(torch.tensor([-1.]) ** t) * 0.5, + action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5, reward=torch.tensor([-1.]) * t, ), batch_size=(1,)) for t in range(cfg.episode_length)] episode1 = torch.cat(transitions1) From 2f86a1e4d8f49abdf49acc1b65f77907abf7e0e9 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 25 Dec 2023 10:11:42 -0800 Subject: [PATCH 08/10] fix sampler https://github.com/pytorch/rl/pull/1762 --- tdmpc2/common/samplers.py | 8 ++++++-- tdmpc2/train.py | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py index af0e073..a8cd72f 100644 --- a/tdmpc2/common/samplers.py +++ b/tdmpc2/common/samplers.py @@ -233,7 +233,9 @@ class SliceSampler(Sampler): and self._used_traj_key[0] == "_data" ) vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)]) - return self._cache.setdefault("stop-and-length", vals) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals except KeyError: if fallback: self._fetch_traj = False @@ -257,7 +259,9 @@ class SliceSampler(Sampler): and self._used_end_key[0] == "_data" ) vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)] - return self._cache.setdefault("stop-and-length", vals) + if self.cache_values: + self._cache["stop-and-length"] = vals + return vals except KeyError: if fallback: self._fetch_traj = True diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5303e09..b091bec 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -51,8 +51,7 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=CropBuffer(cfg), - # buffer=SliceBuffer(cfg), + buffer=SliceBuffer(cfg), logger=Logger(cfg), ) trainer.train() From 54145a4d8c4c080836ff1f186fc5a87f70c8a8c7 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 27 Dec 2023 08:49:04 -0800 Subject: [PATCH 09/10] integrate slicesampler as default --- tdmpc2/common/buffer.py | 84 +++++++-------------------------------- tdmpc2/common/samplers.py | 28 +++---------- tdmpc2/test_buffer.py | 64 ----------------------------- tdmpc2/train.py | 4 +- 4 files changed, 22 insertions(+), 158 deletions(-) delete mode 100644 tdmpc2/test_buffer.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 75d17d2..29cc293 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -1,27 +1,28 @@ import torch from tensordict.tensordict import TensorDict from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage -from torchrl.data.replay_buffers.samplers import RandomSampler -from torchrl.envs import RandomCropTensorDict from common.samplers import SliceSampler class Buffer(): """ - Base class for TD-MPC2 replay buffers. + Replay buffer for TD-MPC2 training. Based on torchrl. Uses CUDA memory if available, and CPU memory otherwise. """ def __init__(self, cfg): self.cfg = cfg self._device = torch.device('cuda') - self._capacity = None - self._max_eps = None + self._capacity = min(cfg.buffer_size, cfg.steps) + self._sampler = SliceSampler( + num_slices=self.cfg.batch_size, + end_key=None, + traj_key='episode', + truncated_key=None, + ) + self._batch_size = cfg.batch_size * (cfg.horizon+1) self._num_eps = 0 - self._sampler = None - self._transform = None - self._batch_size = None @property def capacity(self): @@ -42,21 +43,19 @@ class Buffer(): sampler=self._sampler, pin_memory=True, prefetch=1, - transform=self._transform, batch_size=self._batch_size, ) def _init(self, tds): """Initialize the replay buffer. Use the first episode to estimate storage requirements.""" - print('Buffer capacity:', self._capacity) + print(f'Buffer capacity: {self._capacity:,}') mem_free, _ = torch.cuda.mem_get_info() - bytes_per_ep = sum([ + bytes_per_step = sum([ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \ else sum([x.numel()*x.element_size() for x in v.values()])) \ - for k,v in tds.items() - ]) - print(f'Bytes per episode: {bytes_per_ep:,}') - total_bytes = bytes_per_ep*self._max_eps + for v in tds.values() + ]) / len(tds) + total_bytes = bytes_per_step*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' @@ -82,68 +81,15 @@ class Buffer(): task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, task) - def _add(self, td): - """Internal function that adds episode to the buffer.""" - pass - def add(self, td): """Add an episode to the buffer.""" td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps if self._num_eps == 0: self._buffer = self._init(td) - self._add(td) + self._buffer.extend(td) self._num_eps += 1 return self._num_eps - def sample(self): - """Sample a batch of sub-trajectories from the buffer.""" - pass - - -class CropBuffer(Buffer): - """ - A replay buffer that first samples trajectories, and then crops to desired length. - """ - - def __init__(self, cfg): - super().__init__(cfg) - self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length - self._max_eps = self._capacity - self._sampler = RandomSampler() - self._transform = RandomCropTensorDict(cfg.horizon+1, -1) - self._batch_size = cfg.batch_size - - def _add(self, td): - """Add an episode to the buffer, with trajectories as the leading dimension.""" - self._buffer.add(td) - - def sample(self): - """Sample a batch of subsequences from the buffer.""" - td = self._buffer.sample().permute(1,0) - return self._prepare_batch(td) - - -class SliceBuffer(Buffer): - """ - A replay buffer that directly samples subsequences. More efficient than CropBuffer. - """ - - def __init__(self, cfg): - super().__init__(cfg) - self._capacity = min(cfg.buffer_size, cfg.steps) - self._max_eps = self._capacity//cfg.episode_length - self._sampler = SliceSampler( - num_slices=self.cfg.batch_size, - end_key=None, - traj_key='episode', - truncated_key=None, - ) - self._batch_size = cfg.batch_size * (cfg.horizon+1) - - def _add(self, td): - """Add an episode to the buffer, with transitions as the leading dimension.""" - self._buffer.extend(td) - def sample(self): """Sample a batch of subsequences from the buffer.""" td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0) diff --git a/tdmpc2/common/samplers.py b/tdmpc2/common/samplers.py index a8cd72f..71cf70c 100644 --- a/tdmpc2/common/samplers.py +++ b/tdmpc2/common/samplers.py @@ -1,36 +1,18 @@ from __future__ import annotations -import json -import warnings -from abc import ABC, abstractmethod -from copy import copy, deepcopy -from multiprocessing.context import get_spawning_popen -from pathlib import Path -from typing import Any, Dict, Tuple, Union +from copy import copy +from typing import Tuple -import numpy as np import torch -from tensordict import MemoryMappedTensor from tensordict.utils import NestedKey -from torchrl._extension import EXTENSION_WARNING - -try: - from torchrl._torchrl import ( - MinSegmentTreeFp32, - MinSegmentTreeFp64, - SumSegmentTreeFp32, - SumSegmentTreeFp64, - ) -except ImportError: - warnings.warn(EXTENSION_WARNING) - from torchrl.data.replay_buffers.storages import Storage, TensorStorage -from torchrl.data.replay_buffers.utils import _to_numpy, INT_CLASSES from torchrl.data.replay_buffers.samplers import Sampler -_EMPTY_STORAGE_ERROR = "Cannot sample from an empty storage." + +# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html +# This copy will live here until it has been included in a few torchrl stable releases class SliceSampler(Sampler): diff --git a/tdmpc2/test_buffer.py b/tdmpc2/test_buffer.py deleted file mode 100644 index 87e9212..0000000 --- a/tdmpc2/test_buffer.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -os.environ['LAZY_LEGACY_OP'] = '0' - -import torch -import hydra -from tensordict.tensordict import TensorDict - -from common.buffer import CropBuffer, SliceBuffer - - -@hydra.main(config_name='config', config_path='.') -def test_buffer(cfg: dict): - cfg.episode_length = 11 - cfg.batch_size = 8 - - transitions0 = [TensorDict(dict( - obs=torch.tensor([0., 1., 2., 3., 4.]).unsqueeze(0) + t, - action=torch.tensor([-1.]).unsqueeze(0) ** t, - reward=torch.tensor([1.]) * t, - ), batch_size=(1,)) for t in range(cfg.episode_length)] - episode0 = torch.cat(transitions0) - - transitions1 = [TensorDict(dict( - obs=torch.tensor([20., 21., 22., 23., 24.]).unsqueeze(0) + t, - action=(torch.tensor([-1.]) ** t).unsqueeze(0) * 0.5, - reward=torch.tensor([-1.]) * t, - ), batch_size=(1,)) for t in range(cfg.episode_length)] - episode1 = torch.cat(transitions1) - - crop_buffer = CropBuffer(cfg) - slice_buffer = SliceBuffer(cfg) - - crop_buffer.add(episode0) - slice_buffer.add(episode0) - - crop_buffer.add(episode1) - slice_buffer.add(episode1) - - crop_obs, crop_action, crop_reward, _ = crop_buffer.sample() - slice_obs, slice_action, slice_reward, _ = slice_buffer.sample() - - assert crop_obs.shape == slice_obs.shape - assert crop_action.shape == slice_action.shape - assert crop_reward.shape == slice_reward.shape - - assert (crop_obs[1:] - crop_obs[:-1] == 1.).all() - assert (slice_obs[1:] - slice_obs[:-1] == 1.).all() - assert (crop_action.mean().abs() < 0.2) - assert (slice_action.mean().abs() < 0.2) - - crop_rewards, slice_rewards = [], [] - for _ in range(100_000): - _, _, crop_reward_, _ = crop_buffer.sample() - _, _, slice_reward_, _ = slice_buffer.sample() - crop_rewards.append(crop_reward_.mean()) - slice_rewards.append(slice_reward_.mean()) - - crop_rewards = torch.tensor(crop_rewards).mean() - slice_rewards = torch.tensor(slice_rewards).mean() - assert (crop_rewards - slice_rewards) < 0.1 - - -if __name__ == '__main__': - test_buffer() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index b091bec..5953bb2 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -10,7 +10,7 @@ from termcolor import colored from common.parser import parse_cfg from common.seed import set_seed -from common.buffer import CropBuffer, SliceBuffer +from common.buffer import Buffer from envs import make_env from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer @@ -51,7 +51,7 @@ def train(cfg: dict): cfg=cfg, env=make_env(cfg), agent=TDMPC2(cfg), - buffer=SliceBuffer(cfg), + buffer=Buffer(cfg), logger=Logger(cfg), ) trainer.train() From 6cb779aa3a294260b7522570193cb78835124864 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Thu, 28 Dec 2023 07:33:03 -0800 Subject: [PATCH 10/10] allow missing env dependencies + update readme --- README.md | 14 ++++++++---- docker/environment.yaml | 1 + docker/environment_minimal.yaml | 39 ++++++++++++++++++++++++++++++++ tdmpc2/envs/__init__.py | 35 ++++++++++++++++++++-------- tdmpc2/envs/dmcontrol.py | 4 ++-- tdmpc2/envs/exceptions.py | 4 ---- tdmpc2/envs/maniskill.py | 4 ++-- tdmpc2/envs/metaworld.py | 4 ++-- tdmpc2/envs/myosuite.py | 28 ++++++++++------------- tdmpc2/trainer/online_trainer.py | 2 +- 10 files changed, 95 insertions(+), 40 deletions(-) create mode 100644 docker/environment_minimal.yaml delete mode 100644 tdmpc2/envs/exceptions.py diff --git a/README.md b/README.md index a3ee4f1..94534a4 100755 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
-This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL. +This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL. ---- @@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image cd docker && docker build . -t /tdmpc2:0.1.0 ``` -If you prefer to install dependencies manually, start by installing dependencies via `conda` by running +If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands: ``` conda env create -f docker/environment.yaml +conda env create -f docker/environment_minimal.yaml ``` +The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks. + If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running ``` @@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr | metaworld | mw-pick-place-wall | maniskill | pick-cube | maniskill | pick-ycb -| myosuite | myo-hand-key-turn -| myosuite | myo-hand-key-turn-hard +| myosuite | myo-key-turn +| myosuite | myo-key-turn-hard which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. +**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies. + ## Example usage @@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and $ python train.py task=mt80 model_size=48 batch_size=1024 $ python train.py task=mt30 model_size=317 batch_size=1024 $ python train.py task=dog-run steps=7000000 +$ python train.py task=walker-walk obs=rgb ``` We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. diff --git a/docker/environment.yaml b/docker/environment.yaml index 18a9914..6792839 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -26,6 +26,7 @@ dependencies: - hydra-core - hydra-submitit-launcher - submitit + - pandas - patchelf - protobuf - tqdm diff --git a/docker/environment_minimal.yaml b/docker/environment_minimal.yaml new file mode 100644 index 0000000..fbe30f6 --- /dev/null +++ b/docker/environment_minimal.yaml @@ -0,0 +1,39 @@ +name: tdmpc2 +channels: + - pytorch-nightly + - nvidia + - conda-forge + - defaults +dependencies: + - python=3.9.0 + - pytorch + - torchvision + - cudatoolkit=11.7 + - glew + - glib + - pip==21 + - pip: + - absl-py + - glfw + - kornia + - termcolor + - gym==0.21.0 + - moviepy + - ffmpeg + - imageio + - imageio-ffmpeg + - omegaconf + - hydra-core + - hydra-submitit-launcher + - submitit + - pandas + - patchelf + - protobuf + - tqdm + - setuptools==65.5.0 + - "cython<3" + - dm-control + - pillow + - tensordict-nightly + - torchrl-nightly + - wandb diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 5efcb73..6326a9e 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -6,11 +6,27 @@ import gym from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper -from envs.dmcontrol import make_env as make_dm_control_env -# from envs.maniskill import make_env as make_maniskill_env -# from envs.metaworld import make_env as make_metaworld_env -# from envs.myosuite import make_env as make_myosuite_env -from envs.exceptions import UnknownTaskError + +def missing_dependencies(task): + raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') + +try: + from envs.dmcontrol import make_env as make_dm_control_env +except: + make_dm_control_env = missing_dependencies +try: + from envs.maniskill import make_env as make_maniskill_env +except: + make_maniskill_env = missing_dependencies +try: + from envs.metaworld import make_env as make_metaworld_env +except: + make_metaworld_env = missing_dependencies +try: + from envs.myosuite import make_env as make_myosuite_env +except: + make_myosuite_env = missing_dependencies + warnings.filterwarnings('ignore', category=DeprecationWarning) @@ -27,7 +43,7 @@ def make_multitask_env(cfg): _cfg.multitask = False env = make_env(_cfg) if env is None: - raise UnknownTaskError(task) + raise ValueError('Unknown task:', task) envs.append(env) env = MultitaskWrapper(cfg, envs) cfg.obs_shapes = env._obs_dims @@ -43,15 +59,16 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) + else: env = None - for fn in [make_dm_control_env]: #, make_maniskill_env, make_metaworld_env, make_myosuite_env]: + for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) - except UnknownTaskError: + except ValueError: pass if env is None: - raise UnknownTaskError(cfg.task) + raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 32cb4b6..97be75a 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) from dm_control.suite.wrappers import action_scale from dm_env import StepType, specs -from envs.exceptions import UnknownTaskError import gym @@ -187,7 +186,8 @@ def make_env(cfg): domain, task = cfg.task.replace('-', '_').split('_', 1) domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain) if (domain, task) not in suite.ALL_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', task) + assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.' env = suite.load(domain, task, task_kwargs={'random': cfg.seed}, diff --git a/tdmpc2/envs/exceptions.py b/tdmpc2/envs/exceptions.py deleted file mode 100644 index 9bf1390..0000000 --- a/tdmpc2/envs/exceptions.py +++ /dev/null @@ -1,4 +0,0 @@ - -class UnknownTaskError(Exception): - def __init__(self, task): - super().__init__(f'Unknown task: {task}') diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 1d2e4c9..7b0b6ed 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -1,7 +1,6 @@ import gym import numpy as np from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError import mani_skill2.envs @@ -65,7 +64,8 @@ def make_env(cfg): Make ManiSkill2 environment. """ if cfg.task not in MANISKILL_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' task_cfg = MANISKILL_TASKS[cfg.task] env = gym.make( task_cfg['env'], diff --git a/tdmpc2/envs/metaworld.py b/tdmpc2/envs/metaworld.py index fd7379d..f5f4f0d 100644 --- a/tdmpc2/envs/metaworld.py +++ b/tdmpc2/envs/metaworld.py @@ -1,7 +1,6 @@ import numpy as np import gym from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE @@ -44,7 +43,8 @@ def make_env(cfg): """ env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable" if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = MetaWorldWrapper(env, cfg) env = TimeLimit(env, max_episode_steps=100) diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py index c503782..fa6876e 100644 --- a/tdmpc2/envs/myosuite.py +++ b/tdmpc2/envs/myosuite.py @@ -1,24 +1,19 @@ import numpy as np import gym from envs.wrappers.time_limit import TimeLimit -from envs.exceptions import UnknownTaskError MYOSUITE_TASKS = { - 'myo-finger-reach': 'myoFingerReachFixed-v0', - 'myo-finger-reach-hard': 'myoFingerReachRandom-v0', - 'myo-finger-pose': 'myoFingerPoseFixed-v0', - 'myo-finger-pose-hard': 'myoFingerPoseRandom-v0', - 'myo-hand-reach': 'myoHandReachFixed-v0', - 'myo-hand-reach-hard': 'myoHandReachRandom-v0', - 'myo-hand-pose': 'myoHandPoseFixed-v0', - 'myo-hand-pose-hard': 'myoHandPoseRandom-v0', - 'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0', - 'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0', - 'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0', - 'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0', - 'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0', - 'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0', + 'myo-reach': 'myoHandReachFixed-v0', + 'myo-reach-hard': 'myoHandReachRandom-v0', + 'myo-pose': 'myoHandPoseFixed-v0', + 'myo-pose-hard': 'myoHandPoseRandom-v0', + 'myo-obj-hold': 'myoHandObjHoldFixed-v0', + 'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0', + 'myo-key-turn': 'myoHandKeyTurnFixed-v0', + 'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0', + 'myo-pen-twirl': 'myoHandPenTwirlFixed-v0', + 'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0', } @@ -50,7 +45,8 @@ def make_env(cfg): Make Myosuite environment. """ if not cfg.task in MYOSUITE_TASKS: - raise UnknownTaskError(cfg.task) + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' import myosuite env = gym.make(MYOSUITE_TASKS[cfg.task]) env = MyoSuiteWrapper(env, cfg) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index f5f65cc..ca33009 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -54,7 +54,7 @@ class OnlineTrainer(Trainer): else: obs = obs.unsqueeze(0).cpu() if action is None: - action = torch.empty_like(self.env.rand_act()) + action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: reward = torch.tensor(float('nan')) td = TensorDict(dict(