Merge pull request #10 from nicklashansen/experimental

[Feature] Faster replay buffer + support pixel observations
This commit is contained in:
Nicklas Hansen
2023-12-28 16:37:27 +01:00
committed by GitHub
13 changed files with 494 additions and 106 deletions

View File

@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. We hope that this repository will serve as a useful community resource for future research on model-based RL.
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
----
@@ -32,12 +32,15 @@ We provide a `Dockerfile` for easy installation. You can build the docker image
cd docker && docker build . -t <user>/tdmpc2:0.1.0
```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
```
conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml
```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
```
@@ -72,11 +75,13 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| metaworld | mw-pick-place-wall
| maniskill | pick-cube
| maniskill | pick-ycb
| myosuite | myo-hand-key-turn
| myosuite | myo-hand-key-turn-hard
| myosuite | myo-key-turn
| myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
## Example usage
@@ -102,6 +107,7 @@ See below examples on how to train TD-MPC**2** on a single task (online RL) and
$ python train.py task=mt80 model_size=48 batch_size=1024
$ python train.py task=mt30 model_size=317 batch_size=1024
$ python train.py task=dog-run steps=7000000
$ python train.py task=walker-walk obs=rgb
```
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.

View File

@@ -26,6 +26,7 @@ dependencies:
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm

View File

@@ -0,0 +1,39 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

View File

@@ -1,43 +1,27 @@
from pathlib import Path
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import RandomSampler
from torchrl.envs import RandomCropTensorDict, Transform, Compose
from common.logger import make_dir
class DataPrepTransform(Transform):
"""
Preprocesses data for TD-MPC2 training.
Replay data is expected to be a TensorDict with the following keys:
obs: observations
action: actions
reward: rewards
task: task IDs (optional)
A TensorDict with T time steps has T+1 observations and T actions and rewards.
The first actions and rewards in each TensorDict are dummies and should be ignored.
"""
def __init__(self):
super().__init__([])
def forward(self, td):
td = td.permute(1,0)
return td['obs'], td['action'][1:], td['reward'][1:].unsqueeze(-1), (td['task'][0] if 'task' in td.keys() else None)
from common.samplers import SliceSampler
class Buffer():
"""
Create a replay buffer for TD-MPC2 training.
Replay buffer for TD-MPC2 training. Based on torchrl.
Uses CUDA memory if available, and CPU memory otherwise.
"""
def __init__(self, cfg):
self.cfg = cfg
self._device = torch.device('cuda')
self._capacity = min(cfg.buffer_size, cfg.steps)//cfg.episode_length
self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler(
num_slices=self.cfg.batch_size,
end_key=None,
traj_key='episode',
truncated_key=None,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0
@property
@@ -53,63 +37,60 @@ class Buffer():
def _reserve_buffer(self, storage):
"""
Reserve a buffer with the given storage.
Uses the RandomSampler to sample trajectories,
and the RandomCropTensorDict transform to crop trajectories to the desired length.
DataPrepTransform is used to preprocess data to the expected format in TD-MPC2 updates.
"""
return ReplayBuffer(
storage=storage,
sampler=RandomSampler(),
sampler=self._sampler,
pin_memory=True,
prefetch=1,
transform=Compose(
RandomCropTensorDict(self.cfg.horizon+1, -1),
DataPrepTransform(),
),
batch_size=self.cfg.batch_size,
batch_size=self._batch_size,
)
def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info()
bytes_per_ep = sum([
bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
else sum([x.numel()*x.element_size() for x in v.values()])) \
for k,v in tds.items()
])
print(f'Bytes per episode: {bytes_per_ep:,}')
total_bytes = bytes_per_ep*self._capacity
for v in tds.values()
]) / len(tds)
total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory
if 2.5*total_bytes > mem_free: # Insufficient CUDA memory
print('Using CPU memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cpu'))
)
else: # Sufficient CUDA memory
print('Using CUDA memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device('cuda'))
)
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.')
return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
)
def add(self, tds):
"""Add an episode to the buffer. All episodes are expected to have the same length."""
def _to_device(self, *args, device=None):
if device is None:
device = self._device
return (arg.to(device, non_blocking=True) \
if arg is not None else None for arg in args)
def _prepare_batch(self, td):
"""
Prepare a sampled batch for training (post-processing).
Expects `td` to be a TensorDict with batch size TxB.
"""
obs = td['obs']
action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task)
def add(self, td):
"""Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
if self._num_eps == 0:
self._buffer = self._init(tds)
self._buffer.add(tds)
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_eps += 1
return self._num_eps
def sample(self):
"""Sample a batch of sub-trajectories from the buffer."""
obs, action, reward, task = self._buffer.sample(batch_size=self.cfg.batch_size)
return obs.to(self._device, non_blocking=True), \
action.to(self._device, non_blocking=True), \
reward.to(self._device, non_blocking=True), \
task.to(self._device, non_blocking=True) if task is not None else None
def save(self):
"""Save the buffer to disk. Useful for storing offline datasets."""
td = self._buffer._storage._storage.cpu()
fp = make_dir(Path(self.cfg.buffer_dir) / self.cfg.task / str(self.cfg.seed)) / f'{self._num_eps}.pt'
torch.save(td, fp)
"""Sample a batch of subsequences from the buffer."""
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
return self._prepare_batch(td)

351
tdmpc2/common/samplers.py Normal file
View File

@@ -0,0 +1,351 @@
from __future__ import annotations
from copy import copy
from typing import Tuple
import torch
from tensordict.utils import NestedKey
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
slice_len (int): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``("next", "done")``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
trajectory breaks. Defaults to ``("next", "truncated")``.
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
instances (otherwise the truncated key is returned in the info dictionary
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
... storage=LazyMemmapStorage(1_000_000),
... sampler=SliceSampler(cache_values=True, num_slices=10),
... batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
... {
... "episode": episode,
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
... "act": torch.randn((20,)).expand(1000, 20),
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
... }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:
Examples:
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
... batch = batch.reshape(num_slices, -1)
... break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
[14],
[ 8],
[10],
[13],
[ 4],
[ 2],
[ 3],
[22],
[ 8]])
"""
def __init__(
self,
*,
num_slices: int = None,
slice_len: int = None,
end_key: NestedKey | None = None,
traj_key: NestedKey | None = None,
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
) -> object:
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "episode"
if not ((num_slices is None) ^ (slice_len is None)):
raise TypeError(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.num_slices = num_slices
self.slice_len = slice_len
self.end_key = end_key
self.traj_key = traj_key
self.truncated_key = truncated_key
self.cache_values = cache_values
self._fetch_traj = True
self._uses_data_prefix = False
self.strict_length = strict_length
self._cache = {}
@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
# stop_idx = stop_idx.cumsum(0) - 1
# even slower
# t = trajectory.unsqueeze(0)
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
# stop_idx = torch.conv1d(t, w).nonzero()
# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
if end.ndim != 1:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
return start_idx, stop_idx, lengths
def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
else:
# when padding is needed
return torch.cat(
[
_start
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
for _start, _seq_len in zip(start, seq_length)
]
)
def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")
if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage._storage.get(self._used_traj_key)
except KeyError:
trajectory = storage._storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
self._uses_data_prefix = (
isinstance(self._used_traj_key, tuple)
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise
else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage._storage.get(self._used_end_key)
except KeyError:
done = storage._storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
self._uses_data_prefix = (
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise
def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
return seq_length, num_slices
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if not isinstance(storage, TensorStorage):
raise RuntimeError(
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
)
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
def _sample_slices(
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
) -> Tuple[torch.Tensor, dict]:
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths.clamp_max(seq_length)
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length)
)
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
index = self._tensor_slices_from_startend(seq_length, starts)
if self.truncated_key is not None:
truncated_key = self.truncated_key
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
return index.to(torch.long), {truncated_key: truncated}
return index.to(torch.long), {}
@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)
@_used_traj_key.setter
def _used_traj_key(self, value):
self.__dict__["__used_traj_key"] = value
@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)
@_used_end_key.setter
def _used_end_key(self, value):
self.__dict__["__used_end_key"] = value
def _empty(self):
pass
def dumps(self, path):
# no op - cache does not need to be saved
...
def loads(self, path):
# no op
...
def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state

View File

@@ -6,11 +6,27 @@ import gym
from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper
from envs.dmcontrol import make_env as make_dm_control_env
from envs.maniskill import make_env as make_maniskill_env
from envs.metaworld import make_env as make_metaworld_env
from envs.myosuite import make_env as make_myosuite_env
from envs.exceptions import UnknownTaskError
def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
try:
from envs.dmcontrol import make_env as make_dm_control_env
except:
make_dm_control_env = missing_dependencies
try:
from envs.maniskill import make_env as make_maniskill_env
except:
make_maniskill_env = missing_dependencies
try:
from envs.metaworld import make_env as make_metaworld_env
except:
make_metaworld_env = missing_dependencies
try:
from envs.myosuite import make_env as make_myosuite_env
except:
make_myosuite_env = missing_dependencies
warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -27,7 +43,7 @@ def make_multitask_env(cfg):
_cfg.multitask = False
env = make_env(_cfg)
if env is None:
raise UnknownTaskError(task)
raise ValueError('Unknown task:', task)
envs.append(env)
env = MultitaskWrapper(cfg, envs)
cfg.obs_shapes = env._obs_dims
@@ -43,15 +59,16 @@ def make_env(cfg):
gym.logger.set_level(40)
if cfg.multitask:
env = make_multitask_env(cfg)
else:
env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try:
env = fn(cfg)
except UnknownTaskError:
except ValueError:
pass
if env is None:
raise UnknownTaskError(cfg.task)
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)

View File

@@ -8,7 +8,6 @@ suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
from envs.exceptions import UnknownTaskError
import gym
@@ -187,7 +186,8 @@ def make_env(cfg):
domain, task = cfg.task.replace('-', '_').split('_', 1)
domain = dict(cup='ball_in_cup', pointmass='point_mass').get(domain, domain)
if (domain, task) not in suite.ALL_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', task)
assert cfg.obs in {'state', 'rgb'}, 'This task only supports state and rgb observations.'
env = suite.load(domain,
task,
task_kwargs={'random': cfg.seed},

View File

@@ -1,4 +0,0 @@
class UnknownTaskError(Exception):
def __init__(self, task):
super().__init__(f'Unknown task: {task}')

View File

@@ -1,7 +1,6 @@
import gym
import numpy as np
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
import mani_skill2.envs
@@ -65,7 +64,8 @@ def make_env(cfg):
Make ManiSkill2 environment.
"""
if cfg.task not in MANISKILL_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
task_cfg = MANISKILL_TASKS[cfg.task]
env = gym.make(
task_cfg['env'],

View File

@@ -1,7 +1,6 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -44,7 +43,8 @@ def make_env(cfg):
"""
env_id = cfg.task.split("-", 1)[-1] + "-v2-goal-observable"
if not cfg.task.startswith('mw-') or env_id not in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100)

View File

@@ -1,24 +1,19 @@
import numpy as np
import gym
from envs.wrappers.time_limit import TimeLimit
from envs.exceptions import UnknownTaskError
MYOSUITE_TASKS = {
'myo-finger-reach': 'myoFingerReachFixed-v0',
'myo-finger-reach-hard': 'myoFingerReachRandom-v0',
'myo-finger-pose': 'myoFingerPoseFixed-v0',
'myo-finger-pose-hard': 'myoFingerPoseRandom-v0',
'myo-hand-reach': 'myoHandReachFixed-v0',
'myo-hand-reach-hard': 'myoHandReachRandom-v0',
'myo-hand-pose': 'myoHandPoseFixed-v0',
'myo-hand-pose-hard': 'myoHandPoseRandom-v0',
'myo-hand-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-hand-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-hand-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-hand-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-hand-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-hand-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
'myo-reach': 'myoHandReachFixed-v0',
'myo-reach-hard': 'myoHandReachRandom-v0',
'myo-pose': 'myoHandPoseFixed-v0',
'myo-pose-hard': 'myoHandPoseRandom-v0',
'myo-obj-hold': 'myoHandObjHoldFixed-v0',
'myo-obj-hold-hard': 'myoHandObjHoldRandom-v0',
'myo-key-turn': 'myoHandKeyTurnFixed-v0',
'myo-key-turn-hard': 'myoHandKeyTurnRandom-v0',
'myo-pen-twirl': 'myoHandPenTwirlFixed-v0',
'myo-pen-twirl-hard': 'myoHandPenTwirlRandom-v0',
}
@@ -50,7 +45,8 @@ def make_env(cfg):
Make Myosuite environment.
"""
if not cfg.task in MYOSUITE_TASKS:
raise UnknownTaskError(cfg.task)
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg)

View File

@@ -1,5 +1,6 @@
import os
os.environ['MUJOCO_GL'] = 'egl'
os.environ['LAZY_LEGACY_OP'] = '0'
import warnings
warnings.filterwarnings('ignore')
import torch

View File

@@ -50,11 +50,11 @@ class OnlineTrainer(Trainer):
def to_td(self, obs, action=None, reward=None):
"""Creates a TensorDict for a new episode."""
if isinstance(obs, dict):
obs = TensorDict({k: v.unsqueeze(0) for k,v in obs.items()}, batch_size=(1,)).cpu()
obs = TensorDict(obs, batch_size=(), device='cpu')
else:
obs = obs.unsqueeze(0).cpu()
if action is None:
action = torch.empty_like(self.env.rand_act())
action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
td = TensorDict(dict(