12 Commits

Author | SHA1 | Message | Date
Nicklas Hansen | 4c03df676c | update pinned torchrl version | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | 8c299529a8 | Update README.md | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | e96d4ae1a6 | reduce # wandb calls | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | d28b03b3f9 | update dockerfile | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | 614122644d | update dockerfile + pin all versions | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | dc39c23067 | minor fix in print | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | 173131ca48 | migrate to slicebuffer from torchrl-nightly | 2024-07-02 10:12:30 -07:00
Nicklas Hansen | 594299d7d1 | Merge branch 'uncertainty-regularization' of github.com:nicklashansen/tdmpc2 into uncertainty-regularization | 2024-01-08 11:00:17 -08:00
Nicklas Hansen | 188bd201aa | disable uncertainty estimation when coef=0 | 2024-01-08 10:55:46 -08:00
Nicklas Hansen | 392b16ac89 | add uncertainty regularization | 2024-01-08 10:55:46 -08:00
Nicklas Hansen | e5c9029c86 | disable uncertainty estimation when coef=0 | 2024-01-04 19:39:44 -08:00
Nicklas Hansen | 194c92331c | add uncertainty regularization | 2024-01-03 18:11:32 -08:00
11 changed files with 132 additions and 465 deletions

@@ -2,13 +2,13 @@
Official implementation of
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://nicklashansen.github.io/td-mpc2) by
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://www.tdmpc2.com) by
[Nicklas Hansen](https://nicklashansen.github.io/), [Hao Su](https://cseweb.ucsd.edu/~haosu/)\*, [Xiaolong Wang](https://xiaolonw.github.io/)\* (UC San Diego)</br>
[Nicklas Hansen](https://nicklashansen.github.io), [Hao Su](https://cseweb.ucsd.edu/~haosu)\*, [Xiaolong Wang](https://xiaolonw.github.io)\* (UC San Diego)</br>
<img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br>
[[Website]](https://nicklashansen.github.io/td-mpc2) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://nicklashansen.github.io/td-mpc2/models) [[Dataset]](https://nicklashansen.github.io/td-mpc2/dataset)
[[Website]](https://www.tdmpc2.com) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://www.tdmpc2.com/models) [[Dataset]](https://www.tdmpc2.com/dataset)
----
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://www.tdmpc2.com/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://www.tdmpc2.com/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
----
@@ -29,17 +29,19 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
We provide a `Dockerfile` for easy installation. You can build the docker image by running
```
cd docker && docker build . -t <user>/tdmpc2:0.1.0
cd docker && docker build . -t <user>/tdmpc2:1.0.0
```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
```
conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml
pip install gym==0.21.0
```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
@@ -118,16 +120,23 @@ We recommend using default hyperparameters for single-task online RL, including
## Citation
If you find our work useful, please consider citing the paper as follows:
If you find our work useful, please consider citing our paper as follows:
```
@misc{hansen2023tdmpc2,
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
year={2023},
eprint={2310.16828},
archivePrefix={arXiv},
primaryClass={cs.LG}
@inproceedings{hansen2024tdmpc2,
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
```
as well as the original TD-MPC paper:
```
@inproceedings{hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={International Conference on Machine Learning (ICML)},
year={2022}
}
```

@@ -1,10 +1,18 @@
##########################################
# Dockerfile for TD-MPC2 #
# TD-MPC2 Anonymous Authors, 2023 (c) #
# Nicklas Hansen, 2023 (c) #
# https://www.tdmpc2.com #
# -------------------------------------- #
# Instructions: #
# docker build . -t <user>/tdmpc2:0.1.0 #
# docker push <user>/tdmpc2:0.1.0 #
# Build instructions: #
# docker build . -t <user>/tdmpc2:1.0.0 #
# docker push <user>/tdmpc2:1.0.0 #
# -------------------------------------- #
# Run: #
# docker run -i \ #
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
# --gpus all \ #
# -t <user>/tdmpc2:1.0.0 \ #
# /bin/bash #
##########################################
# base image
@@ -36,19 +44,15 @@ SHELL ["/bin/bash", "-c"]
# conda environment
COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
COPY environment.yaml /root
RUN conda env update -n base -f /root/environment.yaml && \
RUN conda update conda && \
conda env update -n base -f /root/environment.yaml && \
rm /root/environment.yaml && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y && \
conda clean -ya && \
pip cache purge
# environment variables
# mujoco 2.1.0
ENV MUJOCO_GL egl
ENV MS2_ASSET_DIR /root/data
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
# mujoco (required for metaworld)
RUN mkdir -p /root/.mujoco && \
wget https://www.tdmpc2.com/files/mjkey.txt && \
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
@@ -56,4 +60,23 @@ RUN mkdir -p /root/.mujoco && \
rm mujoco210-linux-x86_64.tar.gz && \
mv mujoco210 /root/.mujoco/mujoco210 && \
mv mjkey.txt /root/.mujoco/mjkey.txt && \
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
python -c "import mujoco_py"
# gym
RUN pip install gym==0.21.0
# metaworld
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# maniskill2
ENV MS2_ASSET_DIR /root/data
RUN pip install mani-skill2==0.4.1 && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y
# myosuite (conflicts with meta-world / mani-skill2)
# RUN pip install myosuite
# success!
RUN echo "Successfully built TD-MPC2 Docker image!"

@@ -5,49 +5,62 @@ channels:
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- glew=2.1.0
- glib=2.68.4
- pip=21.0
- python=3.9.0
- pytorch>=2.2.2
- torchvision>=0.16.2
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- transforms3d
- joblib
- opencv-python
- opencv-contrib-python
- filelock
- sapien==2.2.1
- mani-skill2==0.4.1
- trimesh
- open3d
- setuptools==65.5.0
- absl-py==2.0.0
- "cython<3"
- dm-control==1.0.8
- ffmpeg==1.4
- glfw==2.6.4
- hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0
- imageio==2.33.1
- imageio-ffmpeg==0.4.9
- kornia==0.7.1
- moviepy==1.0.3
- mujoco==2.3.1
- mujoco-py==2.1.2.14
- dm-control
- pillow
- pyquaternion
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
- tensordict-nightly
- torchrl-nightly
- wandb
- numpy==1.23.5
- omegaconf==2.3.0
- open3d==0.18.0
- opencv-contrib-python==4.9.0.80
- opencv-python==4.9.0.80
- pandas==2.1.4
- sapien==2.2.1
- submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- protobuf==4.25.2
- pillow==10.2.0
- pyquaternion==0.9.9
- tensordict-nightly==2024.3.26
- termcolor==2.4.0
- torchrl-nightly==2024.3.26
- transforms3d==0.4.1
- trimesh==4.0.9
- tqdm==4.66.1
- wandb==0.16.2
- wheel==0.38.0
####################
# Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite)
# - gym==0.21.0
####################
# ManiSkill2:
# (requires gym==0.21.0 which occasionally breaks)
# - mani-skill2==0.4.1
####################
# Meta-World:
# (requires gym==0.21.0 which occasionally breaks)
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
####################
# MyoSuite:
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
# - myosuite
####################

@@ -1,39 +0,0 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

@@ -1,8 +1,7 @@
import torch
from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from common.samplers import SliceSampler
from torchrl.data.replay_buffers.samplers import SliceSampler
class Buffer():
@@ -20,6 +19,7 @@ class Buffer():
end_key=None,
traj_key='episode',
truncated_key=None,
strict_length=True,
)
self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0
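
The buffer requests `cfg.batch_size * (cfg.horizon+1)` transitions per sample, which the slice sampler then splits into sub-trajectories. The sketch below reproduces that arithmetic and the `strict_length=True` check in plain Python. It assumes the sampler is configured with `num_slices` equal to the batch size (that argument is not visible in this hunk), and the helper names and the values 256 and 3 are hypothetical.

```python
def adjusted_batch_size(batch_size, num_slices=None, slice_len=None):
    """Return (seq_length, num_slices) for a flat batch of transitions."""
    if (num_slices is None) == (slice_len is None):
        raise TypeError("Exactly one of num_slices or slice_len must be given.")
    if num_slices is not None:
        if batch_size % num_slices != 0:
            raise RuntimeError("batch_size must be divisible by num_slices.")
        return batch_size // num_slices, num_slices
    if batch_size % slice_len != 0:
        raise RuntimeError("batch_size must be divisible by slice_len.")
    return slice_len, batch_size // slice_len


def episode_lengths(episode_ids):
    """Lengths of consecutive runs of equal episode ids."""
    lengths = []
    for i, ep in enumerate(episode_ids):
        if i > 0 and ep == episode_ids[i - 1]:
            lengths[-1] += 1
        else:
            lengths.append(1)
    return lengths


def check_strict_length(lengths, seq_length):
    """Mimic strict_length=True: refuse episodes shorter than one slice."""
    if any(n < seq_length for n in lengths):
        raise RuntimeError("Some stored trajectories are shorter than the slice length.")


# Hypothetical config values, for illustration only.
batch_size, horizon = 256, 3
flat_batch = batch_size * (horizon + 1)  # what Buffer stores in _batch_size
seq_length, num_slices = adjusted_batch_size(flat_batch, num_slices=batch_size)
check_strict_length(episode_lengths([1, 1, 1, 1, 1, 2, 2, 2, 2]), seq_length)
```

Under these assumptions each sampled slice covers `horizon + 1` consecutive steps of a single episode, so any stored episode shorter than that would raise rather than silently shrink the batch.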

@@ -227,8 +227,10 @@ class Logger:
xkey = "step"
elif category == "pretrain":
xkey = "iteration"
_d = dict()
for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d[xkey])
_d[category + "/" + k] = v
self._wandb.log(_d, step=d[xkey])
if category == "eval" and self._save_csv:
keys = ["step", "episode_reward"]
self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
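
The logger change above collects all metrics into one dictionary and issues a single log call per step instead of one call per key. A minimal sketch of the difference follows. `FakeRun` is a hypothetical stand-in that only records calls, not the real wandb API, and the metric values are made up.

```python
class FakeRun:
    """Hypothetical stand-in for a wandb run; it only records log() calls."""

    def __init__(self):
        self.calls = []

    def log(self, data, step=None):
        self.calls.append((dict(data), step))


def log_per_key(run, d, category, xkey):
    """Old behaviour: one log call for every metric."""
    for k, v in d.items():
        run.log({category + "/" + k: v}, step=d[xkey])


def log_batched(run, d, category, xkey):
    """New behaviour: prefix every key, then log once."""
    _d = {category + "/" + k: v for k, v in d.items()}
    run.log(_d, step=d[xkey])


# Made-up metrics, for illustration only.
metrics = {"step": 1000, "episode_reward": 512.3, "total_time": 42.0}
old_run, new_run = FakeRun(), FakeRun()
log_per_key(old_run, metrics, "train", "step")
log_batched(new_run, metrics, "train", "step")
print(len(old_run.calls), len(new_run.calls))
```

Both variants record the same keys at the same step. The batched form makes the number of calls independent of the number of metrics.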

@@ -1,351 +0,0 @@
from __future__ import annotations
from copy import copy
from typing import Tuple
import torch
from tensordict.utils import NestedKey
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
slice_len (int): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``("next", "done")``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
trajectory breaks. Defaults to ``("next", "truncated")``.
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
instances (otherwise the truncated key is returned in the info dictionary
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
... storage=LazyMemmapStorage(1_000_000),
... sampler=SliceSampler(cache_values=True, num_slices=10),
... batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
... {
... "episode": episode,
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
... "act": torch.randn((20,)).expand(1000, 20),
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
... }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:
Examples:
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
... batch = batch.reshape(num_slices, -1)
... break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
[14],
[ 8],
[10],
[13],
[ 4],
[ 2],
[ 3],
[22],
[ 8]])
"""
def __init__(
self,
*,
num_slices: int = None,
slice_len: int = None,
end_key: NestedKey | None = None,
traj_key: NestedKey | None = None,
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
) -> object:
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "episode"
if not ((num_slices is None) ^ (slice_len is None)):
raise TypeError(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.num_slices = num_slices
self.slice_len = slice_len
self.end_key = end_key
self.traj_key = traj_key
self.truncated_key = truncated_key
self.cache_values = cache_values
self._fetch_traj = True
self._uses_data_prefix = False
self.strict_length = strict_length
self._cache = {}
@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
# stop_idx = stop_idx.cumsum(0) - 1
# even slower
# t = trajectory.unsqueeze(0)
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
# stop_idx = torch.conv1d(t, w).nonzero()
# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
if end.ndim != 1:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
return start_idx, stop_idx, lengths
def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
else:
# when padding is needed
return torch.cat(
[
_start
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
for _start, _seq_len in zip(start, seq_length)
]
)
def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")
if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage._storage.get(self._used_traj_key)
except KeyError:
trajectory = storage._storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
self._uses_data_prefix = (
isinstance(self._used_traj_key, tuple)
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise
else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage._storage.get(self._used_end_key)
except KeyError:
done = storage._storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
self._uses_data_prefix = (
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise
def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
return seq_length, num_slices
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if not isinstance(storage, TensorStorage):
raise RuntimeError(
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
)
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
def _sample_slices(
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
) -> Tuple[torch.Tensor, dict]:
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths.clamp_max(seq_length)
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length + 1)
)
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
index = self._tensor_slices_from_startend(seq_length, starts)
if self.truncated_key is not None:
truncated_key = self.truncated_key
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
return index.to(torch.long), {truncated_key: truncated}
return index.to(torch.long), {}
@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)
@_used_traj_key.setter
def _used_traj_key(self, value):
self.__dict__["__used_traj_key"] = value
@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)
@_used_end_key.setter
def _used_end_key(self, value):
self.__dict__["__used_end_key"] = value
def _empty(self):
pass
def dumps(self, path):
# no op - cache does not need to be saved
...
def loads(self, path):
# no op
...
def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state

@@ -38,6 +38,7 @@ horizon: 3
min_std: 0.05
max_std: 2
temperature: 0.5
uncertainty_coef: 0
# actor
log_std_min: -10

@@ -44,7 +44,7 @@ def evaluate(cfg: dict):
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
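
The print fix above swaps attribute access for a lookup with a fallback, so a config without `model_size` no longer fails. The sketch below shows the same pattern on a plain `dict`, used here as a stand-in for the parsed config object. The helper name and the example values are hypothetical.

```python
def describe_model_size(cfg):
    """Format the model-size line, falling back to 'default' when the key is absent."""
    return f'Model size: {cfg.get("model_size", "default")}'


# Hypothetical configs: one without and one with an explicit model size.
without_size = describe_model_size({"task": "dog-run"})
with_size = describe_model_size({"task": "dog-run", "model_size": 5})
```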

@@ -90,6 +90,14 @@ class TDMPC2:
else:
a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu()
@torch.no_grad()
def _estimate_uncertainty(self, z, task):
"""Estimates epistemic uncertainty, normalized by predicted value."""
if self.cfg.uncertainty_coef == 0:
return 0
qs = math.two_hot_inv(self.model.Q(z, self.model.pi(z, task)[1], task, return_type='all'), self.cfg)
return qs.mean() * qs.std(0) * self.cfg.uncertainty_coef
@torch.no_grad()
def _estimate_value(self, z, actions, task):
@@ -98,9 +106,10 @@ class TDMPC2:
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
G += discount * (reward - self._estimate_uncertainty(z, task))
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
terminal_value = self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
return G + discount * (terminal_value - self._estimate_uncertainty(z, task))
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):
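
The two hunks above subtract an uncertainty penalty (mean of the Q-ensemble times its standard deviation times `uncertainty_coef`) from every reward and from the terminal value during planning. The following is a simplified scalar sketch of that computation, not the repository's tensor code: it handles one state at a time, uses the sample standard deviation from Python's `statistics` module, and all function names and input values are hypothetical.

```python
import statistics


def estimate_uncertainty(q_ensemble, coef):
    """Ensemble disagreement scaled by the mean predicted value; 0 when coef is 0."""
    if coef == 0:
        return 0.0
    return statistics.fmean(q_ensemble) * statistics.stdev(q_ensemble) * coef


def penalized_return(rewards, q_ensembles, terminal_value, gamma, coef):
    """Discounted return with the penalty applied at every step.

    rewards[t] is the predicted reward at step t, and q_ensembles[t] holds the
    ensemble Q-values for the latent state reached after step t.
    """
    G, discount = 0.0, 1.0
    for reward, qs in zip(rewards, q_ensembles):
        G += discount * (reward - estimate_uncertainty(qs, coef))
        discount *= gamma
    # The final state is penalized again when the terminal value is added,
    # mirroring the diff above.
    return G + discount * (terminal_value - estimate_uncertainty(q_ensembles[-1], coef))


# Made-up inputs, for illustration only.
rewards = [1.0, 1.0, 1.0]
q_ensembles = [[1.0, 3.0], [2.0, 2.0], [1.5, 2.5]]
plain = penalized_return(rewards, q_ensembles, terminal_value=4.0, gamma=0.5, coef=0)
penalized = penalized_return(rewards, q_ensembles, terminal_value=4.0, gamma=0.5, coef=0.1)
```

With `coef=0` the result is the ordinary discounted return, which matches the default `uncertainty_coef: 0` added to the config. With a positive coefficient and positive value estimates, states where the ensemble disagrees lower the estimated return.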

@@ -7,8 +7,8 @@ class Trainer:
self.agent = agent
self.buffer = buffer
self.logger = logger
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self):
"""Evaluate a TD-MPC2 agent."""