12 Commits

Author SHA1 Message Date
Nicklas Hansen
4c03df676c update pinned torchrl version 2024-07-02 10:12:30 -07:00
Nicklas Hansen
8c299529a8 Update README.md 2024-07-02 10:12:30 -07:00
Nicklas Hansen
e96d4ae1a6 reduce # wandb calls 2024-07-02 10:12:30 -07:00
Nicklas Hansen
d28b03b3f9 update dockerfile 2024-07-02 10:12:30 -07:00
Nicklas Hansen
614122644d update dockerfile + pin all versions 2024-07-02 10:12:30 -07:00
Nicklas Hansen
dc39c23067 minor fix in print 2024-07-02 10:12:30 -07:00
Nicklas Hansen
173131ca48 migrate to slicebuffer from torchrl-nightly 2024-07-02 10:12:30 -07:00
Nicklas Hansen
594299d7d1 Merge branch 'uncertainty-regularization' of github.com:nicklashansen/tdmpc2 into uncertainty-regularization 2024-01-08 11:00:17 -08:00
Nicklas Hansen
188bd201aa disable uncertainty estimation when coef=0 2024-01-08 10:55:46 -08:00
Nicklas Hansen
392b16ac89 add uncertainty regularization 2024-01-08 10:55:46 -08:00
Nicklas Hansen
e5c9029c86 disable uncertainty estimation when coef=0 2024-01-04 19:39:44 -08:00
Nicklas Hansen
194c92331c add uncertainty regularization 2024-01-03 18:11:32 -08:00
11 changed files with 132 additions and 465 deletions

View File

@@ -2,13 +2,13 @@
Official implementation of Official implementation of
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://nicklashansen.github.io/td-mpc2) by [TD-MPC2: Scalable, Robust World Models for Continuous Control](https://www.tdmpc2.com) by
[Nicklas Hansen](https://nicklashansen.github.io/), [Hao Su](https://cseweb.ucsd.edu/~haosu/)\*, [Xiaolong Wang](https://xiaolonw.github.io/)\* (UC San Diego)</br> [Nicklas Hansen](https://nicklashansen.github.io), [Hao Su](https://cseweb.ucsd.edu/~haosu)\*, [Xiaolong Wang](https://xiaolonw.github.io)\* (UC San Diego)</br>
<img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br> <img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br>
[[Website]](https://nicklashansen.github.io/td-mpc2) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://nicklashansen.github.io/td-mpc2/models) [[Dataset]](https://nicklashansen.github.io/td-mpc2/dataset) [[Website]](https://www.tdmpc2.com) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://www.tdmpc2.com/models) [[Dataset]](https://www.tdmpc2.com/dataset)
---- ----
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/> <img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL. This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://www.tdmpc2.com/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://www.tdmpc2.com/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
---- ----
@@ -29,17 +29,19 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
We provide a `Dockerfile` for easy installation. You can build the docker image by running We provide a `Dockerfile` for easy installation. You can build the docker image by running
``` ```
cd docker && docker build . -t <user>/tdmpc2:0.1.0 cd docker && docker build . -t <user>/tdmpc2:1.0.0
``` ```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands: This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
``` ```
conda env create -f docker/environment.yaml conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml pip install gym==0.21.0
``` ```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks. The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
@@ -118,16 +120,23 @@ We recommend using default hyperparameters for single-task online RL, including
## Citation ## Citation
If you find our work useful, please consider citing the paper as follows: If you find our work useful, please consider citing our paper as follows:
``` ```
@misc{hansen2023tdmpc2, @inproceedings{hansen2024tdmpc2,
title={TD-MPC2: Scalable, Robust World Models for Continuous Control}, title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
author={Nicklas Hansen and Hao Su and Xiaolong Wang}, author={Nicklas Hansen and Hao Su and Xiaolong Wang},
year={2023}, booktitle={International Conference on Learning Representations (ICLR)},
eprint={2310.16828}, year={2024}
archivePrefix={arXiv}, }
primaryClass={cs.LG} ```
as well as the original TD-MPC paper:
```
@inproceedings{hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={International Conference on Machine Learning (ICML)},
year={2022}
} }
``` ```

View File

@@ -1,10 +1,18 @@
########################################## ##########################################
# Dockerfile for TD-MPC2 # # Dockerfile for TD-MPC2 #
# TD-MPC2 Anonymous Authors, 2023 (c) # # Nicklas Hansen, 2023 (c) #
# https://www.tdmpc2.com #
# -------------------------------------- # # -------------------------------------- #
# Instructions: # # Build instructions: #
# docker build . -t <user>/tdmpc2:0.1.0 # # docker build . -t <user>/tdmpc2:1.0.0 #
# docker push <user>/tdmpc2:0.1.0 # # docker push <user>/tdmpc2:1.0.0 #
# -------------------------------------- #
# Run: #
# docker run -i \ #
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
# --gpus all \ #
# -t <user>/tdmpc2:1.0.0 \ #
# /bin/bash #
########################################## ##########################################
# base image # base image
@@ -36,19 +44,15 @@ SHELL ["/bin/bash", "-c"]
# conda environment # conda environment
COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
COPY environment.yaml /root COPY environment.yaml /root
RUN conda env update -n base -f /root/environment.yaml && \ RUN conda update conda && \
conda env update -n base -f /root/environment.yaml && \
rm /root/environment.yaml && \ rm /root/environment.yaml && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y && \
conda clean -ya && \ conda clean -ya && \
pip cache purge pip cache purge
# environment variables # mujoco 2.1.0
ENV MUJOCO_GL egl ENV MUJOCO_GL egl
ENV MS2_ASSET_DIR /root/data
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
# mujoco (required for metaworld)
RUN mkdir -p /root/.mujoco && \ RUN mkdir -p /root/.mujoco && \
wget https://www.tdmpc2.com/files/mjkey.txt && \ wget https://www.tdmpc2.com/files/mjkey.txt && \
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \ wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
@@ -56,4 +60,23 @@ RUN mkdir -p /root/.mujoco && \
rm mujoco210-linux-x86_64.tar.gz && \ rm mujoco210-linux-x86_64.tar.gz && \
mv mujoco210 /root/.mujoco/mujoco210 && \ mv mujoco210 /root/.mujoco/mujoco210 && \
mv mjkey.txt /root/.mujoco/mjkey.txt && \ mv mjkey.txt /root/.mujoco/mjkey.txt && \
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
python -c "import mujoco_py" python -c "import mujoco_py"
# gym
RUN pip install gym==0.21.0
# metaworld
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# maniskill2
ENV MS2_ASSET_DIR /root/data
RUN pip install mani-skill2==0.4.1 && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y
# myosuite (conflicts with meta-world / mani-skill2)
# RUN pip install myosuite
# success!
RUN echo "Successfully built TD-MPC2 Docker image!"

View File

@@ -5,49 +5,62 @@ channels:
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7 - cudatoolkit=11.7
- glew - glew=2.1.0
- glib - glib=2.68.4
- pip==21 - pip=21.0
- python=3.9.0
- pytorch>=2.2.2
- torchvision>=0.16.2
- pip: - pip:
- absl-py - absl-py==2.0.0
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- transforms3d
- joblib
- opencv-python
- opencv-contrib-python
- filelock
- sapien==2.2.1
- mani-skill2==0.4.1
- trimesh
- open3d
- setuptools==65.5.0
- "cython<3" - "cython<3"
- dm-control==1.0.8
- ffmpeg==1.4
- glfw==2.6.4
- hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0
- imageio==2.33.1
- imageio-ffmpeg==0.4.9
- kornia==0.7.1
- moviepy==1.0.3
- mujoco==2.3.1 - mujoco==2.3.1
- mujoco-py==2.1.2.14 - mujoco-py==2.1.2.14
- dm-control - numpy==1.23.5
- pillow - omegaconf==2.3.0
- pyquaternion - open3d==0.18.0
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb - opencv-contrib-python==4.9.0.80
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed - opencv-python==4.9.0.80
- tensordict-nightly - pandas==2.1.4
- torchrl-nightly - sapien==2.2.1
- wandb - submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- protobuf==4.25.2
- pillow==10.2.0
- pyquaternion==0.9.9
- tensordict-nightly==2024.3.26
- termcolor==2.4.0
- torchrl-nightly==2024.3.26
- transforms3d==0.4.1
- trimesh==4.0.9
- tqdm==4.66.1
- wandb==0.16.2
- wheel==0.38.0
####################
# Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite)
# - gym==0.21.0
####################
# ManiSkill2:
# (requires gym==0.21.0 which occasionally breaks)
# - mani-skill2==0.4.1
####################
# Meta-World:
# (requires gym==0.21.0 which occasionally breaks)
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
####################
# MyoSuite:
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
# - myosuite
####################

View File

@@ -1,39 +0,0 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

View File

@@ -1,8 +1,7 @@
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from common.samplers import SliceSampler
class Buffer(): class Buffer():
@@ -20,6 +19,7 @@ class Buffer():
end_key=None, end_key=None,
traj_key='episode', traj_key='episode',
truncated_key=None, truncated_key=None,
strict_length=True,
) )
self._batch_size = cfg.batch_size * (cfg.horizon+1) self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0 self._num_eps = 0

View File

@@ -227,8 +227,10 @@ class Logger:
xkey = "step" xkey = "step"
elif category == "pretrain": elif category == "pretrain":
xkey = "iteration" xkey = "iteration"
_d = dict()
for k, v in d.items(): for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d[xkey]) _d[category + "/" + k] = v
self._wandb.log(_d, step=d[xkey])
if category == "eval" and self._save_csv: if category == "eval" and self._save_csv:
keys = ["step", "episode_reward"] keys = ["step", "episode_reward"]
self._eval.append(np.array([d[keys[0]], d[keys[1]]])) self._eval.append(np.array([d[keys[0]], d[keys[1]]]))

View File

@@ -1,351 +0,0 @@
from __future__ import annotations
from copy import copy
from typing import Tuple
import torch
from tensordict.utils import NestedKey
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
slice_len (int): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``("next", "done")``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
trajectory breaks. Defaults to ``("next", "truncated")``.
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
instances (otherwise the truncated key is returned in the info dictionary
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
... storage=LazyMemmapStorage(1_000_000),
... sampler=SliceSampler(cache_values=True, num_slices=10),
... batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
... {
... "episode": episode,
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
... "act": torch.randn((20,)).expand(1000, 20),
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
... }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:
Examples:
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
... batch = batch.reshape(num_slices, -1)
... break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
[14],
[ 8],
[10],
[13],
[ 4],
[ 2],
[ 3],
[22],
[ 8]])
"""
def __init__(
self,
*,
num_slices: int = None,
slice_len: int = None,
end_key: NestedKey | None = None,
traj_key: NestedKey | None = None,
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
) -> object:
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "episode"
if not ((num_slices is None) ^ (slice_len is None)):
raise TypeError(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.num_slices = num_slices
self.slice_len = slice_len
self.end_key = end_key
self.traj_key = traj_key
self.truncated_key = truncated_key
self.cache_values = cache_values
self._fetch_traj = True
self._uses_data_prefix = False
self.strict_length = strict_length
self._cache = {}
@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
# stop_idx = stop_idx.cumsum(0) - 1
# even slower
# t = trajectory.unsqueeze(0)
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
# stop_idx = torch.conv1d(t, w).nonzero()
# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
if end.ndim != 1:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
return start_idx, stop_idx, lengths
def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
else:
# when padding is needed
return torch.cat(
[
_start
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
for _start, _seq_len in zip(start, seq_length)
]
)
def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")
if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage._storage.get(self._used_traj_key)
except KeyError:
trajectory = storage._storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
self._uses_data_prefix = (
isinstance(self._used_traj_key, tuple)
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise
else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage._storage.get(self._used_end_key)
except KeyError:
done = storage._storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
self._uses_data_prefix = (
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise
def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
return seq_length, num_slices
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if not isinstance(storage, TensorStorage):
raise RuntimeError(
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
)
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
def _sample_slices(
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
) -> Tuple[torch.Tensor, dict]:
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths.clamp_max(seq_length)
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length + 1)
)
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
index = self._tensor_slices_from_startend(seq_length, starts)
if self.truncated_key is not None:
truncated_key = self.truncated_key
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
return index.to(torch.long), {truncated_key: truncated}
return index.to(torch.long), {}
@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)
@_used_traj_key.setter
def _used_traj_key(self, value):
self.__dict__["__used_traj_key"] = value
@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)
@_used_end_key.setter
def _used_end_key(self, value):
self.__dict__["__used_end_key"] = value
def _empty(self):
pass
def dumps(self, path):
# no op - cache does not need to be saved
...
def loads(self, path):
# no op
...
def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state

View File

@@ -38,6 +38,7 @@ horizon: 3
min_std: 0.05 min_std: 0.05
max_std: 2 max_std: 2
temperature: 0.5 temperature: 0.5
uncertainty_coef: 0
# actor # actor
log_std_min: -10 log_std_min: -10

View File

@@ -44,7 +44,7 @@ def evaluate(cfg: dict):
cfg = parse_cfg(cfg) cfg = parse_cfg(cfg)
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold'])) print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold'])) print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold'])) print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint): if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold'])) print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))

View File

@@ -91,6 +91,14 @@ class TDMPC2:
a = self.model.pi(z, task)[int(not eval_mode)][0] a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu() return a.cpu()
@torch.no_grad()
def _estimate_uncertainty(self, z, task):
"""Estimates epistemic uncertainty, normalized by predicted value."""
if self.cfg.uncertainty_coef == 0:
return 0
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions.""" """Estimate value of a trajectory starting at latent state z and executing given actions."""
@@ -98,9 +106,10 @@ class TDMPC2:
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[t], task)
G += discount * reward G += discount * (reward - self._estimate_uncertainty(z, task))
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
@torch.no_grad() @torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None): def plan(self, z, t0=False, eval_mode=False, task=None):

View File

@@ -7,8 +7,8 @@ class Trainer:
self.agent = agent self.agent = agent
self.buffer = buffer self.buffer = buffer
self.logger = logger self.logger = logger
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
print('Architecture:', self.agent.model) print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""