15 Commits

Author SHA1 Message Date
Nicklas Hansen
b1afbccb05 update pinned torchrl version 2024-07-02 10:07:33 -07:00
Nicklas Hansen
c218c0ff1b update order of arch,params print 2024-02-27 13:18:45 -08:00
Nicklas Hansen
d3bff48d58 Merge branch 'distributed' of github.com:nicklashansen/tdmpc2 into distributed 2024-02-27 13:18:18 -08:00
Nicklas Hansen
c16f2557bb support distributed training 2024-02-27 13:18:14 -08:00
Nicklas Hansen
de87519c60 support distributed training 2024-02-27 13:17:55 -08:00
Nicklas Hansen
57158282b4 Merge branch 'main' of github.com:nicklashansen/tdmpc2 into main 2024-02-02 15:56:56 -08:00
Nicklas Hansen
718966c28d reduce # wandb calls 2024-02-02 15:56:54 -08:00
Nicklas Hansen
01cdf0f799 Update README.md 2024-01-24 21:43:20 -08:00
Nicklas Hansen
02b18a48b1 update dockerfile 2024-01-22 17:37:31 -08:00
Nicklas Hansen
e8f1ed6785 update dockerfile + pin all versions 2024-01-21 21:21:44 -08:00
Nicklas Hansen
8b6fe61bed minor fix in print 2024-01-11 18:19:08 -08:00
Nicklas Hansen
aa9c6f33f5 migrate to slicebuffer from torchrl-nightly 2024-01-10 19:53:30 -08:00
Nicklas Hansen
20f4064dfa Merge branch 'distributed' of github.com:nicklashansen/tdmpc2 into distributed 2024-01-08 10:51:00 -08:00
Nicklas Hansen
c6d1bd85bf support distributed training 2024-01-08 10:50:43 -08:00
Nicklas Hansen
33555b5982 support distributed training 2024-01-07 11:52:53 -08:00
16 changed files with 242 additions and 508 deletions

View File

@@ -2,13 +2,13 @@
Official implementation of Official implementation of
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://nicklashansen.github.io/td-mpc2) by [TD-MPC2: Scalable, Robust World Models for Continuous Control](https://www.tdmpc2.com) by
[Nicklas Hansen](https://nicklashansen.github.io/), [Hao Su](https://cseweb.ucsd.edu/~haosu/)\*, [Xiaolong Wang](https://xiaolonw.github.io/)\* (UC San Diego)</br> [Nicklas Hansen](https://nicklashansen.github.io), [Hao Su](https://cseweb.ucsd.edu/~haosu)\*, [Xiaolong Wang](https://xiaolonw.github.io)\* (UC San Diego)</br>
<img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br> <img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br>
[[Website]](https://nicklashansen.github.io/td-mpc2) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://nicklashansen.github.io/td-mpc2/models) [[Dataset]](https://nicklashansen.github.io/td-mpc2/dataset) [[Website]](https://www.tdmpc2.com) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://www.tdmpc2.com/models) [[Dataset]](https://www.tdmpc2.com/dataset)
---- ----
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/> <img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL. This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://www.tdmpc2.com/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://www.tdmpc2.com/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
---- ----
@@ -29,17 +29,19 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
We provide a `Dockerfile` for easy installation. You can build the docker image by running We provide a `Dockerfile` for easy installation. You can build the docker image by running
``` ```
cd docker && docker build . -t <user>/tdmpc2:0.1.0 cd docker && docker build . -t <user>/tdmpc2:1.0.0
``` ```
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands: This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
``` ```
conda env create -f docker/environment.yaml conda env create -f docker/environment.yaml
conda env create -f docker/environment_minimal.yaml pip install gym==0.21.0
``` ```
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks. The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
@@ -118,16 +120,23 @@ We recommend using default hyperparameters for single-task online RL, including
## Citation ## Citation
If you find our work useful, please consider citing the paper as follows: If you find our work useful, please consider citing our paper as follows:
``` ```
@misc{hansen2023tdmpc2, @inproceedings{hansen2024tdmpc2,
title={TD-MPC2: Scalable, Robust World Models for Continuous Control}, title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
author={Nicklas Hansen and Hao Su and Xiaolong Wang}, author={Nicklas Hansen and Hao Su and Xiaolong Wang},
year={2023}, booktitle={International Conference on Learning Representations (ICLR)},
eprint={2310.16828}, year={2024}
archivePrefix={arXiv}, }
primaryClass={cs.LG} ```
as well as the original TD-MPC paper:
```
@inproceedings{hansen2022tdmpc,
title={Temporal Difference Learning for Model Predictive Control},
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
booktitle={International Conference on Machine Learning (ICML)},
year={2022}
} }
``` ```

View File

@@ -1,10 +1,18 @@
########################################## ##########################################
# Dockerfile for TD-MPC2 # # Dockerfile for TD-MPC2 #
# TD-MPC2 Anonymous Authors, 2023 (c) # # Nicklas Hansen, 2023 (c) #
# https://www.tdmpc2.com #
# -------------------------------------- # # -------------------------------------- #
# Instructions: # # Build instructions: #
# docker build . -t <user>/tdmpc2:0.1.0 # # docker build . -t <user>/tdmpc2:1.0.0 #
# docker push <user>/tdmpc2:0.1.0 # # docker push <user>/tdmpc2:1.0.0 #
# -------------------------------------- #
# Run: #
# docker run -i \ #
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
# --gpus all \ #
# -t <user>/tdmpc2:1.0.0 \ #
# /bin/bash #
########################################## ##########################################
# base image # base image
@@ -36,19 +44,15 @@ SHELL ["/bin/bash", "-c"]
# conda environment # conda environment
COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
COPY environment.yaml /root COPY environment.yaml /root
RUN conda env update -n base -f /root/environment.yaml && \ RUN conda update conda && \
conda env update -n base -f /root/environment.yaml && \
rm /root/environment.yaml && \ rm /root/environment.yaml && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y && \
conda clean -ya && \ conda clean -ya && \
pip cache purge pip cache purge
# environment variables # mujoco 2.1.0
ENV MUJOCO_GL egl ENV MUJOCO_GL egl
ENV MS2_ASSET_DIR /root/data
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
# mujoco (required for metaworld)
RUN mkdir -p /root/.mujoco && \ RUN mkdir -p /root/.mujoco && \
wget https://www.tdmpc2.com/files/mjkey.txt && \ wget https://www.tdmpc2.com/files/mjkey.txt && \
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \ wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
@@ -56,4 +60,23 @@ RUN mkdir -p /root/.mujoco && \
rm mujoco210-linux-x86_64.tar.gz && \ rm mujoco210-linux-x86_64.tar.gz && \
mv mujoco210 /root/.mujoco/mujoco210 && \ mv mujoco210 /root/.mujoco/mujoco210 && \
mv mjkey.txt /root/.mujoco/mjkey.txt && \ mv mjkey.txt /root/.mujoco/mjkey.txt && \
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
python -c "import mujoco_py" python -c "import mujoco_py"
# gym
RUN pip install gym==0.21.0
# metaworld
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# maniskill2
ENV MS2_ASSET_DIR /root/data
RUN pip install mani-skill2==0.4.1 && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y
# myosuite (conflicts with meta-world / mani-skill2)
# RUN pip install myosuite
# success!
RUN echo "Successfully built TD-MPC2 Docker image!"

View File

@@ -5,49 +5,62 @@ channels:
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7 - cudatoolkit=11.7
- glew - glew=2.1.0
- glib - glib=2.68.4
- pip==21 - pip=21.0
- python=3.9.0
- pytorch>=2.2.2
- torchvision>=0.16.2
- pip: - pip:
- absl-py - absl-py==2.0.0
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- transforms3d
- joblib
- opencv-python
- opencv-contrib-python
- filelock
- sapien==2.2.1
- mani-skill2==0.4.1
- trimesh
- open3d
- setuptools==65.5.0
- "cython<3" - "cython<3"
- dm-control==1.0.8
- ffmpeg==1.4
- glfw==2.6.4
- hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0
- imageio==2.33.1
- imageio-ffmpeg==0.4.9
- kornia==0.7.1
- moviepy==1.0.3
- mujoco==2.3.1 - mujoco==2.3.1
- mujoco-py==2.1.2.14 - mujoco-py==2.1.2.14
- dm-control - numpy==1.23.5
- pillow - omegaconf==2.3.0
- pyquaternion - open3d==0.18.0
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb - opencv-contrib-python==4.9.0.80
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed - opencv-python==4.9.0.80
- tensordict-nightly - pandas==2.1.4
- torchrl-nightly - sapien==2.2.1
- wandb - submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- protobuf==4.25.2
- pillow==10.2.0
- pyquaternion==0.9.9
- tensordict-nightly==2024.3.26
- termcolor==2.4.0
- torchrl-nightly==2024.3.26
- transforms3d==0.4.1
- trimesh==4.0.9
- tqdm==4.66.1
- wandb==0.16.2
- wheel==0.38.0
####################
# Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite)
# - gym==0.21.0
####################
# ManiSkill2:
# (requires gym==0.21.0 which occasionally breaks)
# - mani-skill2==0.4.1
####################
# Meta-World:
# (requires gym==0.21.0 which occasionally breaks)
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
####################
# MyoSuite:
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
# - myosuite
####################

View File

@@ -1,39 +0,0 @@
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
- conda-forge
- defaults
dependencies:
- python=3.9.0
- pytorch
- torchvision
- cudatoolkit=11.7
- glew
- glib
- pip==21
- pip:
- absl-py
- glfw
- kornia
- termcolor
- gym==0.21.0
- moviepy
- ffmpeg
- imageio
- imageio-ffmpeg
- omegaconf
- hydra-core
- hydra-submitit-launcher
- submitit
- pandas
- patchelf
- protobuf
- tqdm
- setuptools==65.5.0
- "cython<3"
- dm-control
- pillow
- tensordict-nightly
- torchrl-nightly
- wandb

View File

@@ -1,8 +1,7 @@
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
from torchrl.data.replay_buffers.samplers import SliceSampler
from common.samplers import SliceSampler
class Buffer(): class Buffer():
@@ -13,16 +12,18 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device(self.cfg.rank)
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler( self._sampler = SliceSampler(
num_slices=self.cfg.batch_size, num_slices=self.cfg.batch_size,
end_key=None, end_key=None,
traj_key='episode', traj_key='episode',
truncated_key=None, truncated_key=None,
strict_length=True,
) )
self._batch_size = cfg.batch_size * (cfg.horizon+1) self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0 self._num_eps = 0
self._num_transitions = 0
@property @property
def capacity(self): def capacity(self):
@@ -33,6 +34,11 @@ class Buffer():
def num_eps(self): def num_eps(self):
"""Return the number of episodes in the buffer.""" """Return the number of episodes in the buffer."""
return self._num_eps return self._num_eps
@property
def num_transitions(self):
"""Return the number of transitions in the buffer."""
return self._num_transitions
def _reserve_buffer(self, storage): def _reserve_buffer(self, storage):
""" """
@@ -48,7 +54,11 @@ class Buffer():
def _init(self, tds): def _init(self, tds):
"""Initialize the replay buffer. Use the first episode to estimate storage requirements.""" """Initialize the replay buffer. Use the first episode to estimate storage requirements."""
print(f'Buffer capacity: {self._capacity:,}') if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Buffer capacity per process: {self._capacity:,}')
else:
print(f'Buffer capacity: {self._capacity:,}')
mem_free, _ = torch.cuda.mem_get_info() mem_free, _ = torch.cuda.mem_get_info()
bytes_per_step = sum([ bytes_per_step = sum([
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \ (v.numel()*v.element_size() if not isinstance(v, TensorDict) \
@@ -56,10 +66,15 @@ class Buffer():
for v in tds.values() for v in tds.values()
]) / len(tds) ]) / len(tds)
total_bytes = bytes_per_step*self._capacity total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') if self.cfg.rank == 0:
if self.cfg.world_size > 1:
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
else:
print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.') if self.cfg.rank == 0:
print(f'Using {storage_device.upper()} memory for storage.')
return self._reserve_buffer( return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device)) LazyTensorStorage(self._capacity, device=torch.device(storage_device))
) )
@@ -88,6 +103,7 @@ class Buffer():
self._buffer = self._init(td) self._buffer = self._init(td)
self._buffer.extend(td) self._buffer.extend(td)
self._num_eps += 1 self._num_eps += 1
self._num_transitions += len(td)
return self._num_eps return self._num_eps
def sample(self): def sample(self):

View File

@@ -113,11 +113,13 @@ class Logger:
self._group = cfg_to_group(cfg) self._group = cfg_to_group(cfg)
self._seed = cfg.seed self._seed = cfg.seed
self._eval = [] self._eval = []
print_run(cfg) if cfg.rank == 0:
print_run(cfg)
self.project = cfg.get("wandb_project", "none") self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none") self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none": if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"])) if cfg.rank == 0:
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False cfg.save_agent = False
cfg.save_video = False cfg.save_video = False
self._wandb = None self._wandb = None
@@ -227,8 +229,10 @@ class Logger:
xkey = "step" xkey = "step"
elif category == "pretrain": elif category == "pretrain":
xkey = "iteration" xkey = "iteration"
_d = dict()
for k, v in d.items(): for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d[xkey]) _d[category + "/" + k] = v
self._wandb.log(_d, step=d[xkey])
if category == "eval" and self._save_csv: if category == "eval" and self._save_csv:
keys = ["step", "episode_reward"] keys = ["step", "episode_reward"]
self._eval.append(np.array([d[keys[0]], d[keys[1]]])) self._eval.append(np.array([d[keys[0]], d[keys[1]]]))

View File

@@ -1,351 +0,0 @@
from __future__ import annotations
from copy import copy
from typing import Tuple
import torch
from tensordict.utils import NestedKey
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
from torchrl.data.replay_buffers.samplers import Sampler
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
# This copy will live here until it has been included in a few torchrl stable releases
class SliceSampler(Sampler):
"""Samples slices of data along the first dimension, given start and stop signals.
This class samples sub-trajectories with replacement. For a version without
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
Keyword Args:
num_slices (int): the number of slices to be sampled. The batch-size
must be greater or equal to the ``num_slices`` argument. Exclusive
with ``slice_len``.
slice_len (int): the length of the slices to be sampled. The batch-size
must be greater or equal to the ``slice_len`` argument and divisible
by it. Exclusive with ``num_slices``.
end_key (NestedKey, optional): the key indicating the end of a
trajectory (or episode). Defaults to ``("next", "done")``.
traj_key (NestedKey, optional): the key indicating the trajectories.
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
cache_values (bool, optional): to be used with static datasets.
Will cache the start and end signal of the trajectory.
truncated_key (NestedKey, optional): If not ``None``, this argument
indicates where a truncated signal should be written in the output
data. This is used to indicate to value estimators where the provided
trajectory breaks. Defaults to ``("next", "truncated")``.
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
instances (otherwise the truncated key is returned in the info dictionary
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
strict_length (bool, optional): if ``False``, trajectories of length
shorter than `slice_len` (or `batch_size // num_slices`) will be
allowed to appear in the batch.
Be mindful that this can result in effective `batch_size` shorter
than the one asked for! Trajectories can be split using
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
.. note:: To recover the trajectory splits in the storage,
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
attempt to find the ``traj_key`` entry in the storage. If it cannot be
found, the ``end_key`` will be used to reconstruct the episodes.
Examples:
>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
>>> torch.manual_seed(0)
>>> rb = TensorDictReplayBuffer(
... storage=LazyMemmapStorage(1_000_000),
... sampler=SliceSampler(cache_values=True, num_slices=10),
... batch_size=320,
... )
>>> episode = torch.zeros(1000, dtype=torch.int)
>>> episode[:300] = 1
>>> episode[300:550] = 2
>>> episode[550:700] = 3
>>> episode[700:] = 4
>>> data = TensorDict(
... {
... "episode": episode,
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
... "act": torch.randn((20,)).expand(1000, 20),
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
... }, [1000]
... )
>>> rb.extend(data)
>>> sample = rb.sample()
>>> print("sample:", sample)
>>> print("episodes", sample.get("episode").unique())
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
most of TorchRL's datasets:
Examples:
>>> import torch
>>>
>>> from torchrl.data.datasets import RobosetExperienceReplay
>>> from torchrl.data import SliceSampler
>>>
>>> torch.manual_seed(0)
>>> num_slices = 10
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
>>> for batch in data:
... batch = batch.reshape(num_slices, -1)
... break
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
check that each batch only has one episode: tensor([[19],
[14],
[ 8],
[10],
[13],
[ 4],
[ 2],
[ 3],
[22],
[ 8]])
"""
def __init__(
self,
*,
num_slices: int = None,
slice_len: int = None,
end_key: NestedKey | None = None,
traj_key: NestedKey | None = None,
cache_values: bool = False,
truncated_key: NestedKey | None = ("next", "truncated"),
strict_length: bool = True,
) -> object:
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "episode"
if not ((num_slices is None) ^ (slice_len is None)):
raise TypeError(
"Either num_slices or slice_len must be not None, and not both. "
f"Got num_slices={num_slices} and slice_len={slice_len}."
)
self.num_slices = num_slices
self.slice_len = slice_len
self.end_key = end_key
self.traj_key = traj_key
self.truncated_key = truncated_key
self.cache_values = cache_values
self._fetch_traj = True
self._uses_data_prefix = False
self.strict_length = strict_length
self._cache = {}
@staticmethod
def _find_start_stop_traj(*, trajectory=None, end=None):
if trajectory is not None:
# slower
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
# stop_idx = stop_idx.cumsum(0) - 1
# even slower
# t = trajectory.unsqueeze(0)
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
# stop_idx = torch.conv1d(t, w).nonzero()
# faster
end = trajectory[:-1] != trajectory[1:]
end = torch.cat([end, torch.ones_like(end[:1])], 0)
else:
end = torch.index_fill(
end,
index=torch.tensor(-1, device=end.device, dtype=torch.long),
dim=0,
value=1,
)
if end.ndim != 1:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
return start_idx, stop_idx, lengths
def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
else:
# when padding is needed
return torch.cat(
[
_start
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
for _start, _seq_len in zip(start, seq_length)
]
)
def _get_stop_and_length(self, storage, fallback=True):
if self.cache_values and "stop-and-length" in self._cache:
return self._cache.get("stop-and-length")
if self._fetch_traj:
# We first try with the traj_key
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
trajectory = storage._storage.get(self._used_traj_key)
except KeyError:
trajectory = storage._storage.get(("_data", self.traj_key))
# cache that value for future use
self._used_traj_key = ("_data", self.traj_key)
self._uses_data_prefix = (
isinstance(self._used_traj_key, tuple)
and self._used_traj_key[0] == "_data"
)
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = False
return self._get_stop_and_length(storage, fallback=False)
raise
else:
try:
# In some cases, the storage hides the data behind "_data".
# In the future, this may be deprecated, and we don't want to mess
# with the keys provided by the user so we fall back on a proxy to
# the traj key.
try:
done = storage._storage.get(self._used_end_key)
except KeyError:
done = storage._storage.get(("_data", self.end_key))
# cache that value for future use
self._used_end_key = ("_data", self.end_key)
self._uses_data_prefix = (
isinstance(self._used_end_key, tuple)
and self._used_end_key[0] == "_data"
)
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
if self.cache_values:
self._cache["stop-and-length"] = vals
return vals
except KeyError:
if fallback:
self._fetch_traj = True
return self._get_stop_and_length(storage, fallback=False)
raise
def _adjusted_batch_size(self, batch_size):
if self.num_slices is not None:
if batch_size % self.num_slices != 0:
raise RuntimeError(
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
)
seq_length = batch_size // self.num_slices
num_slices = self.num_slices
else:
if batch_size % self.slice_len != 0:
raise RuntimeError(
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
)
seq_length = self.slice_len
num_slices = batch_size // self.slice_len
return seq_length, num_slices
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if not isinstance(storage, TensorStorage):
raise RuntimeError(
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
)
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
def _sample_slices(
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
) -> Tuple[torch.Tensor, dict]:
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
"Some stored trajectories have a length shorter than the slice that was asked for. "
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
"in you batch."
)
# make seq_length a tensor with values clamped by lengths
seq_length = lengths.clamp_max(seq_length)
if traj_idx is None:
traj_idx = torch.randint(
lengths.shape[0], (num_slices,), device=lengths.device
)
else:
num_slices = traj_idx.shape[0]
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length + 1)
)
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
index = self._tensor_slices_from_startend(seq_length, starts)
if self.truncated_key is not None:
truncated_key = self.truncated_key
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
if isinstance(seq_length, int):
truncated.view(num_slices, -1)[:, -1] = 1
else:
truncated[seq_length.cumsum(0) - 1] = 1
return index.to(torch.long), {truncated_key: truncated}
return index.to(torch.long), {}
@property
def _used_traj_key(self):
return self.__dict__.get("__used_traj_key", self.traj_key)
@_used_traj_key.setter
def _used_traj_key(self, value):
self.__dict__["__used_traj_key"] = value
@property
def _used_end_key(self):
return self.__dict__.get("__used_end_key", self.end_key)
@_used_end_key.setter
def _used_end_key(self, value):
self.__dict__["__used_end_key"] = value
def _empty(self):
pass
def dumps(self, path):
# no op - cache does not need to be saved
...
def loads(self, path):
# no op
...
def __getstate__(self):
state = copy(self.__dict__)
state["_cache"] = {}
return state

View File

@@ -6,8 +6,8 @@ class RunningScale:
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
def state_dict(self): def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles) return dict(value=self._value, percentiles=self._percentiles)

View File

@@ -3,13 +3,15 @@ from copy import deepcopy
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from tensordict.tensordict import TensorDict
from common import layers, math, init from common import layers, math, init
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
TD-MPC2 implicit world model architecture. Distributed version of the TD-MPC2 world model architecture.
Can be used for both single-task and multi-task experiments. Can be used for both single-task and multi-task experiments.
""" """
@@ -17,24 +19,36 @@ class WorldModel(nn.Module):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
if cfg.multitask: if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
for i in range(len(cfg.tasks)): for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1. self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg) self.__encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False) self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min) self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
self.to(cfg.rank)
if cfg.multitask:
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
@property @property
def total_params(self): def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters() if p.requires_grad)
def __repr__(self):
modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]])
return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}"
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
""" """

View File

@@ -11,6 +11,7 @@ eval_episodes: 10
eval_freq: 50000 eval_freq: 50000
# training # training
world_size: 1
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
reward_coef: 0.1 reward_coef: 0.1
@@ -74,6 +75,7 @@ save_agent: true
seed: 1 seed: 1
# convenience # convenience
rank: ???
work_dir: ??? work_dir: ???
task_title: ??? task_title: ???
multitask: ??? multitask: ???

View File

@@ -35,7 +35,8 @@ def make_multitask_env(cfg):
""" """
Make a multi-task environment for TD-MPC2 experiments. Make a multi-task environment for TD-MPC2 experiments.
""" """
print('Creating multi-task environment with tasks:', cfg.tasks) if cfg.rank == 0:
print('Creating multi-task environment with tasks:', cfg.tasks)
envs = [] envs = []
for task in cfg.tasks: for task in cfg.tasks:
_cfg = deepcopy(cfg) _cfg = deepcopy(cfg)

View File

@@ -44,7 +44,7 @@ def evaluate(cfg: dict):
cfg = parse_cfg(cfg) cfg = parse_cfg(cfg)
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold'])) print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold'])) print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold'])) print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint): if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold'])) print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))

View File

@@ -16,8 +16,8 @@ class TDMPC2:
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda') self.device = torch.device(cfg.rank)
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg)
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
@@ -30,7 +30,7 @@ class TDMPC2:
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
def _get_discount(self, episode_length): def _get_discount(self, episode_length):

View File

@@ -14,14 +14,28 @@ from common.buffer import Buffer
from envs import make_env from envs import make_env
from tdmpc2 import TDMPC2 from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger from common.logger import Logger
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
@hydra.main(config_name='config', config_path='.') def setup(rank, world_size):
def train(cfg: dict): os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
torch.distributed.init_process_group(
backend="nccl",
rank=rank,
world_size=world_size
)
def cleanup():
torch.distributed.destroy_process_group()
def train(rank: int, cfg: dict):
""" """
Script for training single-task / multi-task TD-MPC2 agents. Script for training single-task / multi-task TD-MPC2 agents.
@@ -40,14 +54,11 @@ def train(cfg: dict):
$ python train.py task=dog-run steps=7000000 $ python train.py task=dog-run steps=7000000
``` ```
""" """
assert torch.cuda.is_available() setup(rank, cfg.world_size)
assert cfg.steps > 0, 'Must train for at least 1 step.' set_seed(cfg.seed + rank)
cfg = parse_cfg(cfg) cfg.rank = rank
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = OfflineTrainer(
trainer = trainer_cls(
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),
agent=TDMPC2(cfg), agent=TDMPC2(cfg),
@@ -55,8 +66,26 @@ def train(cfg: dict):
logger=Logger(cfg), logger=Logger(cfg),
) )
trainer.train() trainer.train()
print('\nTraining completed successfully') if cfg.rank == 0:
print('\nTraining completed successfully')
cleanup()
@hydra.main(config_name='config', config_path='.')
def launch(cfg: dict):
assert torch.cuda.is_available()
assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
assert cfg.steps > 0, 'Must train for at least 1 step.'
cfg = parse_cfg(cfg)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
torch.multiprocessing.spawn(
train,
args=(cfg,),
nprocs=cfg.world_size,
join=True,
)
if __name__ == '__main__': if __name__ == '__main__':
train() launch()

View File

@@ -7,8 +7,9 @@ class Trainer:
self.agent = agent self.agent = agent
self.buffer = buffer self.buffer = buffer
self.logger = logger self.logger = logger
print("Learnable parameters: {:,}".format(self.agent.model.total_params)) if cfg.rank == 0:
print('Architecture:', self.agent.model) print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""

View File

@@ -50,12 +50,21 @@ class OfflineTrainer(Trainer):
fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp))) fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}' assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}') if self.cfg.rank == 0:
print(f'Found {len(fps)} files in {fp}')
# Distribute data across processes
assert len(fps) >= self.cfg.world_size, \
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
fps = fps[self.cfg.rank::self.cfg.world_size]
print(f'Process {self.cfg.rank} has {len(fps)} files')
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
# Create buffer for sampling # Create buffer for sampling
_cfg = deepcopy(self.cfg) _cfg = deepcopy(self.cfg)
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501 _cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000 _cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
_cfg.buffer_size //= self.cfg.world_size
_cfg.steps = _cfg.buffer_size _cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg) self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'): for fp in tqdm(fps, desc='Loading data'):
@@ -65,10 +74,12 @@ class OfflineTrainer(Trainer):
f'please double-check your config.' f'please double-check your config.'
for i in range(len(td)): for i in range(len(td)):
self.buffer.add(td[i]) self.buffer.add(td[i])
assert self.buffer.num_eps == self.buffer.capacity, \ if self.buffer.num_transitions > self.buffer.capacity:
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' print(f'Buffer has {self.buffer.num_transitions} transitions,' \
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
print(f'Training agent for {self.cfg.steps} iterations...') if self.cfg.rank == 0:
print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {} metrics = {}
for i in range(self.cfg.steps): for i in range(self.cfg.steps):
@@ -76,7 +87,7 @@ class OfflineTrainer(Trainer):
train_metrics = self.agent.update(self.buffer) train_metrics = self.agent.update(self.buffer)
# Evaluate agent periodically # Evaluate agent periodically
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
metrics = { metrics = {
'iteration': i, 'iteration': i,
'total_time': time() - self._start_time, 'total_time': time() - self._start_time,
@@ -89,4 +100,5 @@ class OfflineTrainer(Trainer):
self.logger.save_agent(self.agent, identifier=f'{i}') self.logger.save_agent(self.agent, identifier=f'{i}')
self.logger.log(metrics, 'pretrain') self.logger.log(metrics, 'pretrain')
self.logger.finish(self.agent) if self.cfg.rank == 0:
self.logger.finish(self.agent)