Compare commits
12 Commits
main
...
uncertaint
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4c03df676c | ||
|
|
8c299529a8 | ||
|
|
e96d4ae1a6 | ||
|
|
d28b03b3f9 | ||
|
|
614122644d | ||
|
|
dc39c23067 | ||
|
|
173131ca48 | ||
|
|
594299d7d1 | ||
|
|
188bd201aa | ||
|
|
392b16ac89 | ||
|
|
e5c9029c86 | ||
|
|
194c92331c |
41
README.md
41
README.md
@@ -2,13 +2,13 @@
|
||||
|
||||
Official implementation of
|
||||
|
||||
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://nicklashansen.github.io/td-mpc2) by
|
||||
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://www.tdmpc2.com) by
|
||||
|
||||
[Nicklas Hansen](https://nicklashansen.github.io/), [Hao Su](https://cseweb.ucsd.edu/~haosu/)\*, [Xiaolong Wang](https://xiaolonw.github.io/)\* (UC San Diego)</br>
|
||||
[Nicklas Hansen](https://nicklashansen.github.io), [Hao Su](https://cseweb.ucsd.edu/~haosu)\*, [Xiaolong Wang](https://xiaolonw.github.io)\* (UC San Diego)</br>
|
||||
|
||||
<img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br>
|
||||
|
||||
[[Website]](https://nicklashansen.github.io/td-mpc2) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://nicklashansen.github.io/td-mpc2/models) [[Dataset]](https://nicklashansen.github.io/td-mpc2/dataset)
|
||||
[[Website]](https://www.tdmpc2.com) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://www.tdmpc2.com/models) [[Dataset]](https://www.tdmpc2.com/dataset)
|
||||
|
||||
----
|
||||
|
||||
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
|
||||
|
||||
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
|
||||
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://www.tdmpc2.com/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://www.tdmpc2.com/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
|
||||
----
|
||||
|
||||
@@ -29,17 +29,19 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
|
||||
We provide a `Dockerfile` for easy installation. You can build the docker image by running
|
||||
|
||||
```
|
||||
cd docker && docker build . -t <user>/tdmpc2:0.1.0
|
||||
cd docker && docker build . -t <user>/tdmpc2:1.0.0
|
||||
```
|
||||
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
|
||||
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
|
||||
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
|
||||
|
||||
```
|
||||
conda env create -f docker/environment.yaml
|
||||
conda env create -f docker/environment_minimal.yaml
|
||||
pip install gym==0.21.0
|
||||
```
|
||||
|
||||
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
|
||||
The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
|
||||
|
||||
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
|
||||
|
||||
@@ -118,16 +120,23 @@ We recommend using default hyperparameters for single-task online RL, including
|
||||
|
||||
## Citation
|
||||
|
||||
If you find our work useful, please consider citing the paper as follows:
|
||||
If you find our work useful, please consider citing our paper as follows:
|
||||
|
||||
```
|
||||
@misc{hansen2023tdmpc2,
|
||||
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
|
||||
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
|
||||
year={2023},
|
||||
eprint={2310.16828},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG}
|
||||
@inproceedings{hansen2024tdmpc2,
|
||||
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
|
||||
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
as well as the original TD-MPC paper:
|
||||
```
|
||||
@inproceedings{hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={International Conference on Machine Learning (ICML)},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
##########################################
|
||||
# Dockerfile for TD-MPC2 #
|
||||
# TD-MPC2 Anonymous Authors, 2023 (c) #
|
||||
# Nicklas Hansen, 2023 (c) #
|
||||
# https://www.tdmpc2.com #
|
||||
# -------------------------------------- #
|
||||
# Instructions: #
|
||||
# docker build . -t <user>/tdmpc2:0.1.0 #
|
||||
# docker push <user>/tdmpc2:0.1.0 #
|
||||
# Build instructions: #
|
||||
# docker build . -t <user>/tdmpc2:1.0.0 #
|
||||
# docker push <user>/tdmpc2:1.0.0 #
|
||||
# -------------------------------------- #
|
||||
# Run: #
|
||||
# docker run -i \ #
|
||||
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
|
||||
# --gpus all \ #
|
||||
# -t <user>/tdmpc2:1.0.0 \ #
|
||||
# /bin/bash #
|
||||
##########################################
|
||||
|
||||
# base image
|
||||
@@ -36,19 +44,15 @@ SHELL ["/bin/bash", "-c"]
|
||||
# conda environment
|
||||
COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
|
||||
COPY environment.yaml /root
|
||||
RUN conda env update -n base -f /root/environment.yaml && \
|
||||
RUN conda update conda && \
|
||||
conda env update -n base -f /root/environment.yaml && \
|
||||
rm /root/environment.yaml && \
|
||||
cd /root && \
|
||||
python -m mani_skill2.utils.download_asset all -y && \
|
||||
conda clean -ya && \
|
||||
pip cache purge
|
||||
|
||||
# environment variables
|
||||
# mujoco 2.1.0
|
||||
ENV MUJOCO_GL egl
|
||||
ENV MS2_ASSET_DIR /root/data
|
||||
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
|
||||
|
||||
# mujoco (required for metaworld)
|
||||
RUN mkdir -p /root/.mujoco && \
|
||||
wget https://www.tdmpc2.com/files/mjkey.txt && \
|
||||
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
|
||||
@@ -56,4 +60,23 @@ RUN mkdir -p /root/.mujoco && \
|
||||
rm mujoco210-linux-x86_64.tar.gz && \
|
||||
mv mujoco210 /root/.mujoco/mujoco210 && \
|
||||
mv mjkey.txt /root/.mujoco/mjkey.txt && \
|
||||
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
|
||||
python -c "import mujoco_py"
|
||||
|
||||
# gym
|
||||
RUN pip install gym==0.21.0
|
||||
|
||||
# metaworld
|
||||
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
|
||||
# maniskill2
|
||||
ENV MS2_ASSET_DIR /root/data
|
||||
RUN pip install mani-skill2==0.4.1 && \
|
||||
cd /root && \
|
||||
python -m mani_skill2.utils.download_asset all -y
|
||||
|
||||
# myosuite (conflicts with meta-world / mani-skill2)
|
||||
# RUN pip install myosuite
|
||||
|
||||
# success!
|
||||
RUN echo "Successfully built TD-MPC2 Docker image!"
|
||||
|
||||
@@ -5,49 +5,62 @@ channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.7
|
||||
- glew
|
||||
- glib
|
||||
- pip==21
|
||||
- glew=2.1.0
|
||||
- glib=2.68.4
|
||||
- pip=21.0
|
||||
- python=3.9.0
|
||||
- pytorch>=2.2.2
|
||||
- torchvision>=0.16.2
|
||||
- pip:
|
||||
- absl-py
|
||||
- glfw
|
||||
- kornia
|
||||
- termcolor
|
||||
- gym==0.21.0
|
||||
- moviepy
|
||||
- ffmpeg
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
||||
- omegaconf
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
- transforms3d
|
||||
- joblib
|
||||
- opencv-python
|
||||
- opencv-contrib-python
|
||||
- filelock
|
||||
- sapien==2.2.1
|
||||
- mani-skill2==0.4.1
|
||||
- trimesh
|
||||
- open3d
|
||||
- setuptools==65.5.0
|
||||
- absl-py==2.0.0
|
||||
- "cython<3"
|
||||
- dm-control==1.0.8
|
||||
- ffmpeg==1.4
|
||||
- glfw==2.6.4
|
||||
- hydra-core==1.3.2
|
||||
- hydra-submitit-launcher==1.2.0
|
||||
- imageio==2.33.1
|
||||
- imageio-ffmpeg==0.4.9
|
||||
- kornia==0.7.1
|
||||
- moviepy==1.0.3
|
||||
- mujoco==2.3.1
|
||||
- mujoco-py==2.1.2.14
|
||||
- dm-control
|
||||
- pillow
|
||||
- pyquaternion
|
||||
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
|
||||
- tensordict-nightly
|
||||
- torchrl-nightly
|
||||
- wandb
|
||||
- numpy==1.23.5
|
||||
- omegaconf==2.3.0
|
||||
- open3d==0.18.0
|
||||
- opencv-contrib-python==4.9.0.80
|
||||
- opencv-python==4.9.0.80
|
||||
- pandas==2.1.4
|
||||
- sapien==2.2.1
|
||||
- submitit==1.5.1
|
||||
- setuptools==65.5.0
|
||||
- patchelf==0.17.2.1
|
||||
- protobuf==4.25.2
|
||||
- pillow==10.2.0
|
||||
- pyquaternion==0.9.9
|
||||
- tensordict-nightly==2024.3.26
|
||||
- termcolor==2.4.0
|
||||
- torchrl-nightly==2024.3.26
|
||||
- transforms3d==0.4.1
|
||||
- trimesh==4.0.9
|
||||
- tqdm==4.66.1
|
||||
- wandb==0.16.2
|
||||
- wheel==0.38.0
|
||||
####################
|
||||
# Gym:
|
||||
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
||||
# - gym==0.21.0
|
||||
####################
|
||||
# ManiSkill2:
|
||||
# (requires gym==0.21.0 which occasionally breaks)
|
||||
# - mani-skill2==0.4.1
|
||||
####################
|
||||
# Meta-World:
|
||||
# (requires gym==0.21.0 which occasionally breaks)
|
||||
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
####################
|
||||
# MyoSuite:
|
||||
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
|
||||
# - myosuite
|
||||
####################
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
name: tdmpc2
|
||||
channels:
|
||||
- pytorch-nightly
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.7
|
||||
- glew
|
||||
- glib
|
||||
- pip==21
|
||||
- pip:
|
||||
- absl-py
|
||||
- glfw
|
||||
- kornia
|
||||
- termcolor
|
||||
- gym==0.21.0
|
||||
- moviepy
|
||||
- ffmpeg
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
||||
- omegaconf
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
- setuptools==65.5.0
|
||||
- "cython<3"
|
||||
- dm-control
|
||||
- pillow
|
||||
- tensordict-nightly
|
||||
- torchrl-nightly
|
||||
- wandb
|
||||
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||
|
||||
from common.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
|
||||
|
||||
class Buffer():
|
||||
@@ -20,6 +19,7 @@ class Buffer():
|
||||
end_key=None,
|
||||
traj_key='episode',
|
||||
truncated_key=None,
|
||||
strict_length=True,
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
self._num_eps = 0
|
||||
|
||||
@@ -227,8 +227,10 @@ class Logger:
|
||||
xkey = "step"
|
||||
elif category == "pretrain":
|
||||
xkey = "iteration"
|
||||
_d = dict()
|
||||
for k, v in d.items():
|
||||
self._wandb.log({category + "/" + k: v}, step=d[xkey])
|
||||
_d[category + "/" + k] = v
|
||||
self._wandb.log(_d, step=d[xkey])
|
||||
if category == "eval" and self._save_csv:
|
||||
keys = ["step", "episode_reward"]
|
||||
self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from tensordict.utils import NestedKey
|
||||
|
||||
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
|
||||
|
||||
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
|
||||
# This copy will live here until it has been included in a few torchrl stable releases
|
||||
|
||||
|
||||
class SliceSampler(Sampler):
|
||||
"""Samples slices of data along the first dimension, given start and stop signals.
|
||||
|
||||
This class samples sub-trajectories with replacement. For a version without
|
||||
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
|
||||
|
||||
Keyword Args:
|
||||
num_slices (int): the number of slices to be sampled. The batch-size
|
||||
must be greater or equal to the ``num_slices`` argument. Exclusive
|
||||
with ``slice_len``.
|
||||
slice_len (int): the length of the slices to be sampled. The batch-size
|
||||
must be greater or equal to the ``slice_len`` argument and divisible
|
||||
by it. Exclusive with ``num_slices``.
|
||||
end_key (NestedKey, optional): the key indicating the end of a
|
||||
trajectory (or episode). Defaults to ``("next", "done")``.
|
||||
traj_key (NestedKey, optional): the key indicating the trajectories.
|
||||
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
|
||||
cache_values (bool, optional): to be used with static datasets.
|
||||
Will cache the start and end signal of the trajectory.
|
||||
truncated_key (NestedKey, optional): If not ``None``, this argument
|
||||
indicates where a truncated signal should be written in the output
|
||||
data. This is used to indicate to value estimators where the provided
|
||||
trajectory breaks. Defaults to ``("next", "truncated")``.
|
||||
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
|
||||
instances (otherwise the truncated key is returned in the info dictionary
|
||||
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
|
||||
strict_length (bool, optional): if ``False``, trajectories of length
|
||||
shorter than `slice_len` (or `batch_size // num_slices`) will be
|
||||
allowed to appear in the batch.
|
||||
Be mindful that this can result in effective `batch_size` shorter
|
||||
than the one asked for! Trajectories can be split using
|
||||
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
|
||||
|
||||
.. note:: To recover the trajectory splits in the storage,
|
||||
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
|
||||
attempt to find the ``traj_key`` entry in the storage. If it cannot be
|
||||
found, the ``end_key`` will be used to reconstruct the episodes.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from tensordict import TensorDict
|
||||
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
>>> torch.manual_seed(0)
|
||||
>>> rb = TensorDictReplayBuffer(
|
||||
... storage=LazyMemmapStorage(1_000_000),
|
||||
... sampler=SliceSampler(cache_values=True, num_slices=10),
|
||||
... batch_size=320,
|
||||
... )
|
||||
>>> episode = torch.zeros(1000, dtype=torch.int)
|
||||
>>> episode[:300] = 1
|
||||
>>> episode[300:550] = 2
|
||||
>>> episode[550:700] = 3
|
||||
>>> episode[700:] = 4
|
||||
>>> data = TensorDict(
|
||||
... {
|
||||
... "episode": episode,
|
||||
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
|
||||
... "act": torch.randn((20,)).expand(1000, 20),
|
||||
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
|
||||
... }, [1000]
|
||||
... )
|
||||
>>> rb.extend(data)
|
||||
>>> sample = rb.sample()
|
||||
>>> print("sample:", sample)
|
||||
>>> print("episodes", sample.get("episode").unique())
|
||||
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
|
||||
|
||||
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
|
||||
most of TorchRL's datasets:
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> from torchrl.data.datasets import RobosetExperienceReplay
|
||||
>>> from torchrl.data import SliceSampler
|
||||
>>>
|
||||
>>> torch.manual_seed(0)
|
||||
>>> num_slices = 10
|
||||
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
|
||||
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
|
||||
>>> for batch in data:
|
||||
... batch = batch.reshape(num_slices, -1)
|
||||
... break
|
||||
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
|
||||
check that each batch only has one episode: tensor([[19],
|
||||
[14],
|
||||
[ 8],
|
||||
[10],
|
||||
[13],
|
||||
[ 4],
|
||||
[ 2],
|
||||
[ 3],
|
||||
[22],
|
||||
[ 8]])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
end_key: NestedKey | None = None,
|
||||
traj_key: NestedKey | None = None,
|
||||
cache_values: bool = False,
|
||||
truncated_key: NestedKey | None = ("next", "truncated"),
|
||||
strict_length: bool = True,
|
||||
) -> object:
|
||||
if end_key is None:
|
||||
end_key = ("next", "done")
|
||||
if traj_key is None:
|
||||
traj_key = "episode"
|
||||
if not ((num_slices is None) ^ (slice_len is None)):
|
||||
raise TypeError(
|
||||
"Either num_slices or slice_len must be not None, and not both. "
|
||||
f"Got num_slices={num_slices} and slice_len={slice_len}."
|
||||
)
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.end_key = end_key
|
||||
self.traj_key = traj_key
|
||||
self.truncated_key = truncated_key
|
||||
self.cache_values = cache_values
|
||||
self._fetch_traj = True
|
||||
self._uses_data_prefix = False
|
||||
self.strict_length = strict_length
|
||||
self._cache = {}
|
||||
|
||||
@staticmethod
|
||||
def _find_start_stop_traj(*, trajectory=None, end=None):
|
||||
if trajectory is not None:
|
||||
# slower
|
||||
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
|
||||
# stop_idx = stop_idx.cumsum(0) - 1
|
||||
|
||||
# even slower
|
||||
# t = trajectory.unsqueeze(0)
|
||||
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
|
||||
# stop_idx = torch.conv1d(t, w).nonzero()
|
||||
|
||||
# faster
|
||||
end = trajectory[:-1] != trajectory[1:]
|
||||
end = torch.cat([end, torch.ones_like(end[:1])], 0)
|
||||
else:
|
||||
end = torch.index_fill(
|
||||
end,
|
||||
index=torch.tensor(-1, device=end.device, dtype=torch.long),
|
||||
dim=0,
|
||||
value=1,
|
||||
)
|
||||
if end.ndim != 1:
|
||||
raise RuntimeError(
|
||||
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
|
||||
)
|
||||
stop_idx = end.view(-1).nonzero().view(-1)
|
||||
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
|
||||
lengths = stop_idx - start_idx + 1
|
||||
return start_idx, stop_idx, lengths
|
||||
|
||||
def _tensor_slices_from_startend(self, seq_length, start):
|
||||
if isinstance(seq_length, int):
|
||||
return (
|
||||
torch.arange(
|
||||
seq_length, device=start.device, dtype=start.dtype
|
||||
).unsqueeze(0)
|
||||
+ start.unsqueeze(1)
|
||||
).view(-1)
|
||||
else:
|
||||
# when padding is needed
|
||||
return torch.cat(
|
||||
[
|
||||
_start
|
||||
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
|
||||
for _start, _seq_len in zip(start, seq_length)
|
||||
]
|
||||
)
|
||||
|
||||
def _get_stop_and_length(self, storage, fallback=True):
|
||||
if self.cache_values and "stop-and-length" in self._cache:
|
||||
return self._cache.get("stop-and-length")
|
||||
|
||||
if self._fetch_traj:
|
||||
# We first try with the traj_key
|
||||
try:
|
||||
# In some cases, the storage hides the data behind "_data".
|
||||
# In the future, this may be deprecated, and we don't want to mess
|
||||
# with the keys provided by the user so we fall back on a proxy to
|
||||
# the traj key.
|
||||
try:
|
||||
trajectory = storage._storage.get(self._used_traj_key)
|
||||
except KeyError:
|
||||
trajectory = storage._storage.get(("_data", self.traj_key))
|
||||
# cache that value for future use
|
||||
self._used_traj_key = ("_data", self.traj_key)
|
||||
self._uses_data_prefix = (
|
||||
isinstance(self._used_traj_key, tuple)
|
||||
and self._used_traj_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = False
|
||||
return self._get_stop_and_length(storage, fallback=False)
|
||||
raise
|
||||
|
||||
else:
|
||||
try:
|
||||
# In some cases, the storage hides the data behind "_data".
|
||||
# In the future, this may be deprecated, and we don't want to mess
|
||||
# with the keys provided by the user so we fall back on a proxy to
|
||||
# the traj key.
|
||||
try:
|
||||
done = storage._storage.get(self._used_end_key)
|
||||
except KeyError:
|
||||
done = storage._storage.get(("_data", self.end_key))
|
||||
# cache that value for future use
|
||||
self._used_end_key = ("_data", self.end_key)
|
||||
self._uses_data_prefix = (
|
||||
isinstance(self._used_end_key, tuple)
|
||||
and self._used_end_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = True
|
||||
return self._get_stop_and_length(storage, fallback=False)
|
||||
raise
|
||||
|
||||
def _adjusted_batch_size(self, batch_size):
|
||||
if self.num_slices is not None:
|
||||
if batch_size % self.num_slices != 0:
|
||||
raise RuntimeError(
|
||||
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
|
||||
)
|
||||
seq_length = batch_size // self.num_slices
|
||||
num_slices = self.num_slices
|
||||
else:
|
||||
if batch_size % self.slice_len != 0:
|
||||
raise RuntimeError(
|
||||
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
|
||||
)
|
||||
seq_length = self.slice_len
|
||||
num_slices = batch_size // self.slice_len
|
||||
return seq_length, num_slices
|
||||
|
||||
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
|
||||
if not isinstance(storage, TensorStorage):
|
||||
raise RuntimeError(
|
||||
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
|
||||
)
|
||||
|
||||
# pick up as many trajs as we need
|
||||
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
|
||||
seq_length, num_slices = self._adjusted_batch_size(batch_size)
|
||||
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
|
||||
|
||||
def _sample_slices(
|
||||
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
if (lengths < seq_length).any():
|
||||
if self.strict_length:
|
||||
raise RuntimeError(
|
||||
"Some stored trajectories have a length shorter than the slice that was asked for. "
|
||||
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
|
||||
"in you batch."
|
||||
)
|
||||
# make seq_length a tensor with values clamped by lengths
|
||||
seq_length = lengths.clamp_max(seq_length)
|
||||
|
||||
if traj_idx is None:
|
||||
traj_idx = torch.randint(
|
||||
lengths.shape[0], (num_slices,), device=lengths.device
|
||||
)
|
||||
else:
|
||||
num_slices = traj_idx.shape[0]
|
||||
relative_starts = (
|
||||
(
|
||||
torch.rand(num_slices, device=lengths.device)
|
||||
* (lengths[traj_idx] - seq_length + 1)
|
||||
)
|
||||
.floor()
|
||||
.to(start_idx.dtype)
|
||||
)
|
||||
starts = start_idx[traj_idx] + relative_starts
|
||||
index = self._tensor_slices_from_startend(seq_length, starts)
|
||||
if self.truncated_key is not None:
|
||||
truncated_key = self.truncated_key
|
||||
|
||||
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
|
||||
if isinstance(seq_length, int):
|
||||
truncated.view(num_slices, -1)[:, -1] = 1
|
||||
else:
|
||||
truncated[seq_length.cumsum(0) - 1] = 1
|
||||
return index.to(torch.long), {truncated_key: truncated}
|
||||
return index.to(torch.long), {}
|
||||
|
||||
@property
|
||||
def _used_traj_key(self):
|
||||
return self.__dict__.get("__used_traj_key", self.traj_key)
|
||||
|
||||
@_used_traj_key.setter
|
||||
def _used_traj_key(self, value):
|
||||
self.__dict__["__used_traj_key"] = value
|
||||
|
||||
@property
|
||||
def _used_end_key(self):
|
||||
return self.__dict__.get("__used_end_key", self.end_key)
|
||||
|
||||
@_used_end_key.setter
|
||||
def _used_end_key(self, value):
|
||||
self.__dict__["__used_end_key"] = value
|
||||
|
||||
def _empty(self):
|
||||
pass
|
||||
|
||||
def dumps(self, path):
|
||||
# no op - cache does not need to be saved
|
||||
...
|
||||
|
||||
def loads(self, path):
|
||||
# no op
|
||||
...
|
||||
|
||||
def __getstate__(self):
|
||||
state = copy(self.__dict__)
|
||||
state["_cache"] = {}
|
||||
return state
|
||||
@@ -38,6 +38,7 @@ horizon: 3
|
||||
min_std: 0.05
|
||||
max_std: 2
|
||||
temperature: 0.5
|
||||
uncertainty_coef: 0
|
||||
|
||||
# actor
|
||||
log_std_min: -10
|
||||
|
||||
@@ -44,7 +44,7 @@ def evaluate(cfg: dict):
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
|
||||
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
|
||||
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
|
||||
|
||||
@@ -90,6 +90,14 @@ class TDMPC2:
|
||||
else:
|
||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||
return a.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def _estimate_uncertainty(self, z, task):
|
||||
"""Estimates epistemic uncertainty, normalized by predicted value."""
|
||||
if self.cfg.uncertainty_coef == 0:
|
||||
return 0
|
||||
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
|
||||
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
|
||||
|
||||
@torch.no_grad()
|
||||
def _estimate_value(self, z, actions, task):
|
||||
@@ -98,9 +106,10 @@ class TDMPC2:
|
||||
for t in range(self.cfg.horizon):
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||
z = self.model.next(z, actions[t], task)
|
||||
G += discount * reward
|
||||
G += discount * (reward - self._estimate_uncertainty(z, task))
|
||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z, t0=False, eval_mode=False, task=None):
|
||||
|
||||
@@ -7,8 +7,8 @@ class Trainer:
|
||||
self.agent = agent
|
||||
self.buffer = buffer
|
||||
self.logger = logger
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
print('Architecture:', self.agent.model)
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
|
||||
Reference in New Issue
Block a user