Compare commits
15 Commits
uncertaint
...
distribute
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1afbccb05 | ||
|
|
c218c0ff1b | ||
|
|
d3bff48d58 | ||
|
|
c16f2557bb | ||
|
|
de87519c60 | ||
|
|
57158282b4 | ||
|
|
718966c28d | ||
|
|
01cdf0f799 | ||
|
|
02b18a48b1 | ||
|
|
e8f1ed6785 | ||
|
|
8b6fe61bed | ||
|
|
aa9c6f33f5 | ||
|
|
20f4064dfa | ||
|
|
c6d1bd85bf | ||
|
|
33555b5982 |
41
README.md
41
README.md
@@ -2,13 +2,13 @@
|
||||
|
||||
Official implementation of
|
||||
|
||||
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://nicklashansen.github.io/td-mpc2) by
|
||||
[TD-MPC2: Scalable, Robust World Models for Continuous Control](https://www.tdmpc2.com) by
|
||||
|
||||
[Nicklas Hansen](https://nicklashansen.github.io/), [Hao Su](https://cseweb.ucsd.edu/~haosu/)\*, [Xiaolong Wang](https://xiaolonw.github.io/)\* (UC San Diego)</br>
|
||||
[Nicklas Hansen](https://nicklashansen.github.io), [Hao Su](https://cseweb.ucsd.edu/~haosu)\*, [Xiaolong Wang](https://xiaolonw.github.io)\* (UC San Diego)</br>
|
||||
|
||||
<img src="assets/0.gif" width="12.5%"><img src="assets/1.gif" width="12.5%"><img src="assets/2.gif" width="12.5%"><img src="assets/3.gif" width="12.5%"><img src="assets/4.gif" width="12.5%"><img src="assets/5.gif" width="12.5%"><img src="assets/6.gif" width="12.5%"><img src="assets/7.gif" width="12.5%"></br>
|
||||
|
||||
[[Website]](https://nicklashansen.github.io/td-mpc2) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://nicklashansen.github.io/td-mpc2/models) [[Dataset]](https://nicklashansen.github.io/td-mpc2/dataset)
|
||||
[[Website]](https://www.tdmpc2.com) [[Paper]](https://arxiv.org/abs/2310.16828) [[Models]](https://www.tdmpc2.com/models) [[Dataset]](https://www.tdmpc2.com/dataset)
|
||||
|
||||
----
|
||||
|
||||
@@ -18,7 +18,7 @@ TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm.
|
||||
|
||||
<img src="assets/8.png" width="100%" style="max-width: 640px"><br/>
|
||||
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://nicklashansen.github.io/td-mpc2/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://nicklashansen.github.io/td-mpc2/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
This repository contains code for training and evaluating both single-task online RL and multi-task offline RL TD-MPC**2** agents. We additionally open-source **300+** [model checkpoints](https://www.tdmpc2.com/models) (including 12 multi-task models) across 4 task domains: [DMControl](https://arxiv.org/abs/1801.00690), [Meta-World](https://meta-world.github.io/), [ManiSkill2](https://maniskill2.github.io/), and [MyoSuite](https://sites.google.com/view/myosuite), as well as our [30-task and 80-task datasets](https://www.tdmpc2.com/dataset) used to train the multi-task models. Our codebase supports both state and pixel observations. We hope that this repository will serve as a useful community resource for future research on model-based RL.
|
||||
|
||||
----
|
||||
|
||||
@@ -29,17 +29,19 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
|
||||
We provide a `Dockerfile` for easy installation. You can build the docker image by running
|
||||
|
||||
```
|
||||
cd docker && docker build . -t <user>/tdmpc2:0.1.0
|
||||
cd docker && docker build . -t <user>/tdmpc2:1.0.0
|
||||
```
|
||||
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running one of the following commands:
|
||||
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
|
||||
|
||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
|
||||
|
||||
```
|
||||
conda env create -f docker/environment.yaml
|
||||
conda env create -f docker/environment_minimal.yaml
|
||||
pip install gym==0.21.0
|
||||
```
|
||||
|
||||
The `environment.yaml` file installs dependencies required for all environments, whereas `environment_minimal.yaml` only installs dependencies for training on DMControl tasks.
|
||||
The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
|
||||
|
||||
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
|
||||
|
||||
@@ -118,16 +120,23 @@ We recommend using default hyperparameters for single-task online RL, including
|
||||
|
||||
## Citation
|
||||
|
||||
If you find our work useful, please consider citing the paper as follows:
|
||||
If you find our work useful, please consider citing our paper as follows:
|
||||
|
||||
```
|
||||
@misc{hansen2023tdmpc2,
|
||||
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
|
||||
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
|
||||
year={2023},
|
||||
eprint={2310.16828},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.LG}
|
||||
@inproceedings{hansen2024tdmpc2,
|
||||
title={TD-MPC2: Scalable, Robust World Models for Continuous Control},
|
||||
author={Nicklas Hansen and Hao Su and Xiaolong Wang},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
as well as the original TD-MPC paper:
|
||||
```
|
||||
@inproceedings{hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={International Conference on Machine Learning (ICML)},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
##########################################
|
||||
# Dockerfile for TD-MPC2 #
|
||||
# TD-MPC2 Anonymous Authors, 2023 (c) #
|
||||
# Nicklas Hansen, 2023 (c) #
|
||||
# https://www.tdmpc2.com #
|
||||
# -------------------------------------- #
|
||||
# Instructions: #
|
||||
# docker build . -t <user>/tdmpc2:0.1.0 #
|
||||
# docker push <user>/tdmpc2:0.1.0 #
|
||||
# Build instructions: #
|
||||
# docker build . -t <user>/tdmpc2:1.0.0 #
|
||||
# docker push <user>/tdmpc2:1.0.0 #
|
||||
# -------------------------------------- #
|
||||
# Run: #
|
||||
# docker run -i \ #
|
||||
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
|
||||
# --gpus all \ #
|
||||
# -t <user>/tdmpc2:1.0.0 \ #
|
||||
# /bin/bash #
|
||||
##########################################
|
||||
|
||||
# base image
|
||||
@@ -36,19 +44,15 @@ SHELL ["/bin/bash", "-c"]
|
||||
# conda environment
|
||||
COPY nvidia_icd.json /usr/share/vulkan/icd.d/nvidia_icd.json
|
||||
COPY environment.yaml /root
|
||||
RUN conda env update -n base -f /root/environment.yaml && \
|
||||
RUN conda update conda && \
|
||||
conda env update -n base -f /root/environment.yaml && \
|
||||
rm /root/environment.yaml && \
|
||||
cd /root && \
|
||||
python -m mani_skill2.utils.download_asset all -y && \
|
||||
conda clean -ya && \
|
||||
pip cache purge
|
||||
|
||||
# environment variables
|
||||
# mujoco 2.1.0
|
||||
ENV MUJOCO_GL egl
|
||||
ENV MS2_ASSET_DIR /root/data
|
||||
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
|
||||
|
||||
# mujoco (required for metaworld)
|
||||
RUN mkdir -p /root/.mujoco && \
|
||||
wget https://www.tdmpc2.com/files/mjkey.txt && \
|
||||
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
|
||||
@@ -56,4 +60,23 @@ RUN mkdir -p /root/.mujoco && \
|
||||
rm mujoco210-linux-x86_64.tar.gz && \
|
||||
mv mujoco210 /root/.mujoco/mujoco210 && \
|
||||
mv mjkey.txt /root/.mujoco/mjkey.txt && \
|
||||
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
|
||||
python -c "import mujoco_py"
|
||||
|
||||
# gym
|
||||
RUN pip install gym==0.21.0
|
||||
|
||||
# metaworld
|
||||
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
|
||||
# maniskill2
|
||||
ENV MS2_ASSET_DIR /root/data
|
||||
RUN pip install mani-skill2==0.4.1 && \
|
||||
cd /root && \
|
||||
python -m mani_skill2.utils.download_asset all -y
|
||||
|
||||
# myosuite (conflicts with meta-world / mani-skill2)
|
||||
# RUN pip install myosuite
|
||||
|
||||
# success!
|
||||
RUN echo "Successfully built TD-MPC2 Docker image!"
|
||||
|
||||
@@ -5,49 +5,62 @@ channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.7
|
||||
- glew
|
||||
- glib
|
||||
- pip==21
|
||||
- glew=2.1.0
|
||||
- glib=2.68.4
|
||||
- pip=21.0
|
||||
- python=3.9.0
|
||||
- pytorch>=2.2.2
|
||||
- torchvision>=0.16.2
|
||||
- pip:
|
||||
- absl-py
|
||||
- glfw
|
||||
- kornia
|
||||
- termcolor
|
||||
- gym==0.21.0
|
||||
- moviepy
|
||||
- ffmpeg
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
||||
- omegaconf
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
- transforms3d
|
||||
- joblib
|
||||
- opencv-python
|
||||
- opencv-contrib-python
|
||||
- filelock
|
||||
- sapien==2.2.1
|
||||
- mani-skill2==0.4.1
|
||||
- trimesh
|
||||
- open3d
|
||||
- setuptools==65.5.0
|
||||
- absl-py==2.0.0
|
||||
- "cython<3"
|
||||
- dm-control==1.0.8
|
||||
- ffmpeg==1.4
|
||||
- glfw==2.6.4
|
||||
- hydra-core==1.3.2
|
||||
- hydra-submitit-launcher==1.2.0
|
||||
- imageio==2.33.1
|
||||
- imageio-ffmpeg==0.4.9
|
||||
- kornia==0.7.1
|
||||
- moviepy==1.0.3
|
||||
- mujoco==2.3.1
|
||||
- mujoco-py==2.1.2.14
|
||||
- dm-control
|
||||
- pillow
|
||||
- pyquaternion
|
||||
- git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
# - myosuite # MyoSuite requires gym==0.13.0 which conflicts with Meta-World & ManiSkill2, install separately if needed
|
||||
- tensordict-nightly
|
||||
- torchrl-nightly
|
||||
- wandb
|
||||
- numpy==1.23.5
|
||||
- omegaconf==2.3.0
|
||||
- open3d==0.18.0
|
||||
- opencv-contrib-python==4.9.0.80
|
||||
- opencv-python==4.9.0.80
|
||||
- pandas==2.1.4
|
||||
- sapien==2.2.1
|
||||
- submitit==1.5.1
|
||||
- setuptools==65.5.0
|
||||
- patchelf==0.17.2.1
|
||||
- protobuf==4.25.2
|
||||
- pillow==10.2.0
|
||||
- pyquaternion==0.9.9
|
||||
- tensordict-nightly==2024.3.26
|
||||
- termcolor==2.4.0
|
||||
- torchrl-nightly==2024.3.26
|
||||
- transforms3d==0.4.1
|
||||
- trimesh==4.0.9
|
||||
- tqdm==4.66.1
|
||||
- wandb==0.16.2
|
||||
- wheel==0.38.0
|
||||
####################
|
||||
# Gym:
|
||||
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
||||
# - gym==0.21.0
|
||||
####################
|
||||
# ManiSkill2:
|
||||
# (requires gym==0.21.0 which occasionally breaks)
|
||||
# - mani-skill2==0.4.1
|
||||
####################
|
||||
# Meta-World:
|
||||
# (requires gym==0.21.0 which occasionally breaks)
|
||||
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
||||
####################
|
||||
# MyoSuite:
|
||||
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
|
||||
# - myosuite
|
||||
####################
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
name: tdmpc2
|
||||
channels:
|
||||
- pytorch-nightly
|
||||
- nvidia
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.9.0
|
||||
- pytorch
|
||||
- torchvision
|
||||
- cudatoolkit=11.7
|
||||
- glew
|
||||
- glib
|
||||
- pip==21
|
||||
- pip:
|
||||
- absl-py
|
||||
- glfw
|
||||
- kornia
|
||||
- termcolor
|
||||
- gym==0.21.0
|
||||
- moviepy
|
||||
- ffmpeg
|
||||
- imageio
|
||||
- imageio-ffmpeg
|
||||
- omegaconf
|
||||
- hydra-core
|
||||
- hydra-submitit-launcher
|
||||
- submitit
|
||||
- pandas
|
||||
- patchelf
|
||||
- protobuf
|
||||
- tqdm
|
||||
- setuptools==65.5.0
|
||||
- "cython<3"
|
||||
- dm-control
|
||||
- pillow
|
||||
- tensordict-nightly
|
||||
- torchrl-nightly
|
||||
- wandb
|
||||
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from tensordict.tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers import ReplayBuffer, LazyTensorStorage
|
||||
|
||||
from common.samplers import SliceSampler
|
||||
from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
|
||||
|
||||
class Buffer():
|
||||
@@ -13,16 +12,18 @@ class Buffer():
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._device = torch.device('cuda')
|
||||
self._device = torch.device(self.cfg.rank)
|
||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||
self._sampler = SliceSampler(
|
||||
num_slices=self.cfg.batch_size,
|
||||
end_key=None,
|
||||
traj_key='episode',
|
||||
truncated_key=None,
|
||||
strict_length=True,
|
||||
)
|
||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||
self._num_eps = 0
|
||||
self._num_transitions = 0
|
||||
|
||||
@property
|
||||
def capacity(self):
|
||||
@@ -33,6 +34,11 @@ class Buffer():
|
||||
def num_eps(self):
|
||||
"""Return the number of episodes in the buffer."""
|
||||
return self._num_eps
|
||||
|
||||
@property
|
||||
def num_transitions(self):
|
||||
"""Return the number of transitions in the buffer."""
|
||||
return self._num_transitions
|
||||
|
||||
def _reserve_buffer(self, storage):
|
||||
"""
|
||||
@@ -48,7 +54,11 @@ class Buffer():
|
||||
|
||||
def _init(self, tds):
|
||||
"""Initialize the replay buffer. Use the first episode to estimate storage requirements."""
|
||||
print(f'Buffer capacity: {self._capacity:,}')
|
||||
if self.cfg.rank == 0:
|
||||
if self.cfg.world_size > 1:
|
||||
print(f'Buffer capacity per process: {self._capacity:,}')
|
||||
else:
|
||||
print(f'Buffer capacity: {self._capacity:,}')
|
||||
mem_free, _ = torch.cuda.mem_get_info()
|
||||
bytes_per_step = sum([
|
||||
(v.numel()*v.element_size() if not isinstance(v, TensorDict) \
|
||||
@@ -56,10 +66,15 @@ class Buffer():
|
||||
for v in tds.values()
|
||||
]) / len(tds)
|
||||
total_bytes = bytes_per_step*self._capacity
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
if self.cfg.rank == 0:
|
||||
if self.cfg.world_size > 1:
|
||||
print(f'Storage required per process: {total_bytes/1e9:.2f} GB')
|
||||
else:
|
||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||
# Heuristic: decide whether to use CUDA or CPU memory
|
||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
||||
print(f'Using {storage_device.upper()} memory for storage.')
|
||||
storage_device = self.cfg.rank if 2.5*total_bytes < mem_free else 'cpu'
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Using {storage_device.upper()} memory for storage.')
|
||||
return self._reserve_buffer(
|
||||
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
||||
)
|
||||
@@ -88,6 +103,7 @@ class Buffer():
|
||||
self._buffer = self._init(td)
|
||||
self._buffer.extend(td)
|
||||
self._num_eps += 1
|
||||
self._num_transitions += len(td)
|
||||
return self._num_eps
|
||||
|
||||
def sample(self):
|
||||
|
||||
@@ -113,11 +113,13 @@ class Logger:
|
||||
self._group = cfg_to_group(cfg)
|
||||
self._seed = cfg.seed
|
||||
self._eval = []
|
||||
print_run(cfg)
|
||||
if cfg.rank == 0:
|
||||
print_run(cfg)
|
||||
self.project = cfg.get("wandb_project", "none")
|
||||
self.entity = cfg.get("wandb_entity", "none")
|
||||
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||
if cfg.rank == 0 or cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||
if cfg.rank == 0:
|
||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||
cfg.save_agent = False
|
||||
cfg.save_video = False
|
||||
self._wandb = None
|
||||
@@ -227,8 +229,10 @@ class Logger:
|
||||
xkey = "step"
|
||||
elif category == "pretrain":
|
||||
xkey = "iteration"
|
||||
_d = dict()
|
||||
for k, v in d.items():
|
||||
self._wandb.log({category + "/" + k: v}, step=d[xkey])
|
||||
_d[category + "/" + k] = v
|
||||
self._wandb.log(_d, step=d[xkey])
|
||||
if category == "eval" and self._save_csv:
|
||||
keys = ["step", "episode_reward"]
|
||||
self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from tensordict.utils import NestedKey
|
||||
|
||||
from torchrl.data.replay_buffers.storages import Storage, TensorStorage
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
|
||||
|
||||
# Source: https://pytorch.org/rl/reference/generated/torchrl.data.replay_buffers.SliceSampler.html
|
||||
# This copy will live here until it has been included in a few torchrl stable releases
|
||||
|
||||
|
||||
class SliceSampler(Sampler):
|
||||
"""Samples slices of data along the first dimension, given start and stop signals.
|
||||
|
||||
This class samples sub-trajectories with replacement. For a version without
|
||||
replacement, see :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement`.
|
||||
|
||||
Keyword Args:
|
||||
num_slices (int): the number of slices to be sampled. The batch-size
|
||||
must be greater or equal to the ``num_slices`` argument. Exclusive
|
||||
with ``slice_len``.
|
||||
slice_len (int): the length of the slices to be sampled. The batch-size
|
||||
must be greater or equal to the ``slice_len`` argument and divisible
|
||||
by it. Exclusive with ``num_slices``.
|
||||
end_key (NestedKey, optional): the key indicating the end of a
|
||||
trajectory (or episode). Defaults to ``("next", "done")``.
|
||||
traj_key (NestedKey, optional): the key indicating the trajectories.
|
||||
Defaults to ``"episode"`` (commonly used across datasets in TorchRL).
|
||||
cache_values (bool, optional): to be used with static datasets.
|
||||
Will cache the start and end signal of the trajectory.
|
||||
truncated_key (NestedKey, optional): If not ``None``, this argument
|
||||
indicates where a truncated signal should be written in the output
|
||||
data. This is used to indicate to value estimators where the provided
|
||||
trajectory breaks. Defaults to ``("next", "truncated")``.
|
||||
This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer`
|
||||
instances (otherwise the truncated key is returned in the info dictionary
|
||||
returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method).
|
||||
strict_length (bool, optional): if ``False``, trajectories of length
|
||||
shorter than `slice_len` (or `batch_size // num_slices`) will be
|
||||
allowed to appear in the batch.
|
||||
Be mindful that this can result in effective `batch_size` shorter
|
||||
than the one asked for! Trajectories can be split using
|
||||
:func:`torchrl.collectors.split_trajectories`. Defaults to ``True``.
|
||||
|
||||
.. note:: To recover the trajectory splits in the storage,
|
||||
:class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first
|
||||
attempt to find the ``traj_key`` entry in the storage. If it cannot be
|
||||
found, the ``end_key`` will be used to reconstruct the episodes.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from tensordict import TensorDict
|
||||
>>> from torchrl.data.replay_buffers import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
>>> from torchrl.data.replay_buffers.samplers import SliceSampler
|
||||
>>> torch.manual_seed(0)
|
||||
>>> rb = TensorDictReplayBuffer(
|
||||
... storage=LazyMemmapStorage(1_000_000),
|
||||
... sampler=SliceSampler(cache_values=True, num_slices=10),
|
||||
... batch_size=320,
|
||||
... )
|
||||
>>> episode = torch.zeros(1000, dtype=torch.int)
|
||||
>>> episode[:300] = 1
|
||||
>>> episode[300:550] = 2
|
||||
>>> episode[550:700] = 3
|
||||
>>> episode[700:] = 4
|
||||
>>> data = TensorDict(
|
||||
... {
|
||||
... "episode": episode,
|
||||
... "obs": torch.randn((3, 4, 5)).expand(1000, 3, 4, 5),
|
||||
... "act": torch.randn((20,)).expand(1000, 20),
|
||||
... "other": torch.randn((20, 50)).expand(1000, 20, 50),
|
||||
... }, [1000]
|
||||
... )
|
||||
>>> rb.extend(data)
|
||||
>>> sample = rb.sample()
|
||||
>>> print("sample:", sample)
|
||||
>>> print("episodes", sample.get("episode").unique())
|
||||
episodes tensor([1, 2, 3, 4], dtype=torch.int32)
|
||||
|
||||
:class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with
|
||||
most of TorchRL's datasets:
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> from torchrl.data.datasets import RobosetExperienceReplay
|
||||
>>> from torchrl.data import SliceSampler
|
||||
>>>
|
||||
>>> torch.manual_seed(0)
|
||||
>>> num_slices = 10
|
||||
>>> dataid = list(RobosetExperienceReplay.available_datasets)[0]
|
||||
>>> data = RobosetExperienceReplay(dataid, batch_size=320, sampler=SliceSampler(num_slices=num_slices))
|
||||
>>> for batch in data:
|
||||
... batch = batch.reshape(num_slices, -1)
|
||||
... break
|
||||
>>> print("check that each batch only has one episode:", batch["episode"].unique(dim=1))
|
||||
check that each batch only has one episode: tensor([[19],
|
||||
[14],
|
||||
[ 8],
|
||||
[10],
|
||||
[13],
|
||||
[ 4],
|
||||
[ 2],
|
||||
[ 3],
|
||||
[22],
|
||||
[ 8]])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_slices: int = None,
|
||||
slice_len: int = None,
|
||||
end_key: NestedKey | None = None,
|
||||
traj_key: NestedKey | None = None,
|
||||
cache_values: bool = False,
|
||||
truncated_key: NestedKey | None = ("next", "truncated"),
|
||||
strict_length: bool = True,
|
||||
) -> object:
|
||||
if end_key is None:
|
||||
end_key = ("next", "done")
|
||||
if traj_key is None:
|
||||
traj_key = "episode"
|
||||
if not ((num_slices is None) ^ (slice_len is None)):
|
||||
raise TypeError(
|
||||
"Either num_slices or slice_len must be not None, and not both. "
|
||||
f"Got num_slices={num_slices} and slice_len={slice_len}."
|
||||
)
|
||||
self.num_slices = num_slices
|
||||
self.slice_len = slice_len
|
||||
self.end_key = end_key
|
||||
self.traj_key = traj_key
|
||||
self.truncated_key = truncated_key
|
||||
self.cache_values = cache_values
|
||||
self._fetch_traj = True
|
||||
self._uses_data_prefix = False
|
||||
self.strict_length = strict_length
|
||||
self._cache = {}
|
||||
|
||||
@staticmethod
|
||||
def _find_start_stop_traj(*, trajectory=None, end=None):
|
||||
if trajectory is not None:
|
||||
# slower
|
||||
# _, stop_idx = torch.unique_consecutive(trajectory, return_counts=True)
|
||||
# stop_idx = stop_idx.cumsum(0) - 1
|
||||
|
||||
# even slower
|
||||
# t = trajectory.unsqueeze(0)
|
||||
# w = torch.tensor([1, -1], dtype=torch.int).view(1, 1, 2)
|
||||
# stop_idx = torch.conv1d(t, w).nonzero()
|
||||
|
||||
# faster
|
||||
end = trajectory[:-1] != trajectory[1:]
|
||||
end = torch.cat([end, torch.ones_like(end[:1])], 0)
|
||||
else:
|
||||
end = torch.index_fill(
|
||||
end,
|
||||
index=torch.tensor(-1, device=end.device, dtype=torch.long),
|
||||
dim=0,
|
||||
value=1,
|
||||
)
|
||||
if end.ndim != 1:
|
||||
raise RuntimeError(
|
||||
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
|
||||
)
|
||||
stop_idx = end.view(-1).nonzero().view(-1)
|
||||
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
|
||||
lengths = stop_idx - start_idx + 1
|
||||
return start_idx, stop_idx, lengths
|
||||
|
||||
def _tensor_slices_from_startend(self, seq_length, start):
|
||||
if isinstance(seq_length, int):
|
||||
return (
|
||||
torch.arange(
|
||||
seq_length, device=start.device, dtype=start.dtype
|
||||
).unsqueeze(0)
|
||||
+ start.unsqueeze(1)
|
||||
).view(-1)
|
||||
else:
|
||||
# when padding is needed
|
||||
return torch.cat(
|
||||
[
|
||||
_start
|
||||
+ torch.arange(_seq_len, device=start.device, dtype=start.dtype)
|
||||
for _start, _seq_len in zip(start, seq_length)
|
||||
]
|
||||
)
|
||||
|
||||
def _get_stop_and_length(self, storage, fallback=True):
|
||||
if self.cache_values and "stop-and-length" in self._cache:
|
||||
return self._cache.get("stop-and-length")
|
||||
|
||||
if self._fetch_traj:
|
||||
# We first try with the traj_key
|
||||
try:
|
||||
# In some cases, the storage hides the data behind "_data".
|
||||
# In the future, this may be deprecated, and we don't want to mess
|
||||
# with the keys provided by the user so we fall back on a proxy to
|
||||
# the traj key.
|
||||
try:
|
||||
trajectory = storage._storage.get(self._used_traj_key)
|
||||
except KeyError:
|
||||
trajectory = storage._storage.get(("_data", self.traj_key))
|
||||
# cache that value for future use
|
||||
self._used_traj_key = ("_data", self.traj_key)
|
||||
self._uses_data_prefix = (
|
||||
isinstance(self._used_traj_key, tuple)
|
||||
and self._used_traj_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(trajectory=trajectory[: len(storage)])
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = False
|
||||
return self._get_stop_and_length(storage, fallback=False)
|
||||
raise
|
||||
|
||||
else:
|
||||
try:
|
||||
# In some cases, the storage hides the data behind "_data".
|
||||
# In the future, this may be deprecated, and we don't want to mess
|
||||
# with the keys provided by the user so we fall back on a proxy to
|
||||
# the traj key.
|
||||
try:
|
||||
done = storage._storage.get(self._used_end_key)
|
||||
except KeyError:
|
||||
done = storage._storage.get(("_data", self.end_key))
|
||||
# cache that value for future use
|
||||
self._used_end_key = ("_data", self.end_key)
|
||||
self._uses_data_prefix = (
|
||||
isinstance(self._used_end_key, tuple)
|
||||
and self._used_end_key[0] == "_data"
|
||||
)
|
||||
vals = self._find_start_stop_traj(end=done.squeeze())[: len(storage)]
|
||||
if self.cache_values:
|
||||
self._cache["stop-and-length"] = vals
|
||||
return vals
|
||||
except KeyError:
|
||||
if fallback:
|
||||
self._fetch_traj = True
|
||||
return self._get_stop_and_length(storage, fallback=False)
|
||||
raise
|
||||
|
||||
def _adjusted_batch_size(self, batch_size):
|
||||
if self.num_slices is not None:
|
||||
if batch_size % self.num_slices != 0:
|
||||
raise RuntimeError(
|
||||
f"The batch-size must be divisible by the number of slices, got batch_size={batch_size} and num_slices={self.num_slices}."
|
||||
)
|
||||
seq_length = batch_size // self.num_slices
|
||||
num_slices = self.num_slices
|
||||
else:
|
||||
if batch_size % self.slice_len != 0:
|
||||
raise RuntimeError(
|
||||
f"The batch-size must be divisible by the slice length, got batch_size={batch_size} and slice_len={self.slice_len}."
|
||||
)
|
||||
seq_length = self.slice_len
|
||||
num_slices = batch_size // self.slice_len
|
||||
return seq_length, num_slices
|
||||
|
||||
def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
|
||||
if not isinstance(storage, TensorStorage):
|
||||
raise RuntimeError(
|
||||
f"{type(self)} can only sample from TensorStorage subclasses, got {type(storage)} instead."
|
||||
)
|
||||
|
||||
# pick up as many trajs as we need
|
||||
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
|
||||
seq_length, num_slices = self._adjusted_batch_size(batch_size)
|
||||
return self._sample_slices(lengths, start_idx, seq_length, num_slices)
|
||||
|
||||
def _sample_slices(
|
||||
self, lengths, start_idx, seq_length, num_slices, traj_idx=None
|
||||
) -> Tuple[torch.Tensor, dict]:
|
||||
if (lengths < seq_length).any():
|
||||
if self.strict_length:
|
||||
raise RuntimeError(
|
||||
"Some stored trajectories have a length shorter than the slice that was asked for. "
|
||||
"Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
|
||||
"in you batch."
|
||||
)
|
||||
# make seq_length a tensor with values clamped by lengths
|
||||
seq_length = lengths.clamp_max(seq_length)
|
||||
|
||||
if traj_idx is None:
|
||||
traj_idx = torch.randint(
|
||||
lengths.shape[0], (num_slices,), device=lengths.device
|
||||
)
|
||||
else:
|
||||
num_slices = traj_idx.shape[0]
|
||||
relative_starts = (
|
||||
(
|
||||
torch.rand(num_slices, device=lengths.device)
|
||||
* (lengths[traj_idx] - seq_length + 1)
|
||||
)
|
||||
.floor()
|
||||
.to(start_idx.dtype)
|
||||
)
|
||||
starts = start_idx[traj_idx] + relative_starts
|
||||
index = self._tensor_slices_from_startend(seq_length, starts)
|
||||
if self.truncated_key is not None:
|
||||
truncated_key = self.truncated_key
|
||||
|
||||
truncated = torch.zeros(index.shape, dtype=torch.bool, device=index.device)
|
||||
if isinstance(seq_length, int):
|
||||
truncated.view(num_slices, -1)[:, -1] = 1
|
||||
else:
|
||||
truncated[seq_length.cumsum(0) - 1] = 1
|
||||
return index.to(torch.long), {truncated_key: truncated}
|
||||
return index.to(torch.long), {}
|
||||
|
||||
@property
|
||||
def _used_traj_key(self):
|
||||
return self.__dict__.get("__used_traj_key", self.traj_key)
|
||||
|
||||
@_used_traj_key.setter
|
||||
def _used_traj_key(self, value):
|
||||
self.__dict__["__used_traj_key"] = value
|
||||
|
||||
@property
|
||||
def _used_end_key(self):
|
||||
return self.__dict__.get("__used_end_key", self.end_key)
|
||||
|
||||
@_used_end_key.setter
|
||||
def _used_end_key(self, value):
|
||||
self.__dict__["__used_end_key"] = value
|
||||
|
||||
def _empty(self):
|
||||
pass
|
||||
|
||||
def dumps(self, path):
|
||||
# no op - cache does not need to be saved
|
||||
...
|
||||
|
||||
def loads(self, path):
|
||||
# no op
|
||||
...
|
||||
|
||||
def __getstate__(self):
|
||||
state = copy(self.__dict__)
|
||||
state["_cache"] = {}
|
||||
return state
|
||||
@@ -6,8 +6,8 @@ class RunningScale:
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device(cfg.rank))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device(cfg.rank))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self._value, percentiles=self._percentiles)
|
||||
|
||||
@@ -3,13 +3,15 @@ from copy import deepcopy
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from tensordict.tensordict import TensorDict
|
||||
|
||||
from common import layers, math, init
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
"""
|
||||
TD-MPC2 implicit world model architecture.
|
||||
Distributed version of the TD-MPC2 world model architecture.
|
||||
Can be used for both single-task and multi-task experiments.
|
||||
"""
|
||||
|
||||
@@ -17,24 +19,36 @@ class WorldModel(nn.Module):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
if cfg.multitask:
|
||||
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||
self.__task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
||||
for i in range(len(cfg.tasks)):
|
||||
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.__encoder = layers.enc(cfg)
|
||||
self.__dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self.__reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self.__pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self.__Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
|
||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
||||
self.log_std_min = torch.tensor(cfg.log_std_min)
|
||||
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
|
||||
init.zero_([self.__reward[-1].weight, self.__Qs.params[-2]])
|
||||
self._target_Qs = deepcopy(self.__Qs).requires_grad_(False)
|
||||
self.log_std_min = torch.tensor(cfg.log_std_min, requires_grad=False)
|
||||
self.log_std_dif = torch.tensor(cfg.log_std_max, requires_grad=False) - self.log_std_min
|
||||
self.to(cfg.rank)
|
||||
if cfg.multitask:
|
||||
self._task_emb = DDP(self.__task_emb, device_ids=[cfg.rank])
|
||||
self._encoder = nn.ModuleDict({k: DDP(v, device_ids=[cfg.rank]) for k, v in self.__encoder.items()})
|
||||
self._dynamics = DDP(self.__dynamics, device_ids=[cfg.rank])
|
||||
self._reward = DDP(self.__reward, device_ids=[cfg.rank])
|
||||
self._pi = DDP(self.__pi, device_ids=[cfg.rank])
|
||||
self._Qs = DDP(self.__Qs, device_ids=[cfg.rank])
|
||||
|
||||
@property
|
||||
def total_params(self):
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
def __repr__(self):
|
||||
modules = '\n'.join([str(m) for m in [self._encoder, self._dynamics, self._reward, self._pi, self._Qs]])
|
||||
return f"{self.__class__.__name__}({modules})\nLearnable parameters: {self.total_params:,}"
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -11,6 +11,7 @@ eval_episodes: 10
|
||||
eval_freq: 50000
|
||||
|
||||
# training
|
||||
world_size: 1
|
||||
steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
@@ -74,6 +75,7 @@ save_agent: true
|
||||
seed: 1
|
||||
|
||||
# convenience
|
||||
rank: ???
|
||||
work_dir: ???
|
||||
task_title: ???
|
||||
multitask: ???
|
||||
|
||||
@@ -35,7 +35,8 @@ def make_multitask_env(cfg):
|
||||
"""
|
||||
Make a multi-task environment for TD-MPC2 experiments.
|
||||
"""
|
||||
print('Creating multi-task environment with tasks:', cfg.tasks)
|
||||
if cfg.rank == 0:
|
||||
print('Creating multi-task environment with tasks:', cfg.tasks)
|
||||
envs = []
|
||||
for task in cfg.tasks:
|
||||
_cfg = deepcopy(cfg)
|
||||
|
||||
@@ -44,7 +44,7 @@ def evaluate(cfg: dict):
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
|
||||
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
|
||||
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
|
||||
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
|
||||
|
||||
@@ -16,8 +16,8 @@ class TDMPC2:
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.device = torch.device('cuda')
|
||||
self.model = WorldModel(cfg).to(self.device)
|
||||
self.device = torch.device(cfg.rank)
|
||||
self.model = WorldModel(cfg)
|
||||
self.optim = torch.optim.Adam([
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
@@ -30,7 +30,7 @@ class TDMPC2:
|
||||
self.scale = RunningScale(cfg)
|
||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=cfg.rank
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
|
||||
def _get_discount(self, episode_length):
|
||||
|
||||
@@ -14,14 +14,28 @@ from common.buffer import Buffer
|
||||
from envs import make_env
|
||||
from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
from trainer.online_trainer import OnlineTrainer
|
||||
from common.logger import Logger
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
|
||||
def train(cfg: dict):
|
||||
def setup(rank, world_size):
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "12355"
|
||||
|
||||
# initialize the process group
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
rank=rank,
|
||||
world_size=world_size
|
||||
)
|
||||
|
||||
|
||||
def cleanup():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def train(rank: int, cfg: dict):
|
||||
"""
|
||||
Script for training single-task / multi-task TD-MPC2 agents.
|
||||
|
||||
@@ -40,14 +54,11 @@ def train(cfg: dict):
|
||||
$ python train.py task=dog-run steps=7000000
|
||||
```
|
||||
"""
|
||||
assert torch.cuda.is_available()
|
||||
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
setup(rank, cfg.world_size)
|
||||
set_seed(cfg.seed + rank)
|
||||
cfg.rank = rank
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
trainer = trainer_cls(
|
||||
trainer = OfflineTrainer(
|
||||
cfg=cfg,
|
||||
env=make_env(cfg),
|
||||
agent=TDMPC2(cfg),
|
||||
@@ -55,8 +66,26 @@ def train(cfg: dict):
|
||||
logger=Logger(cfg),
|
||||
)
|
||||
trainer.train()
|
||||
print('\nTraining completed successfully')
|
||||
if cfg.rank == 0:
|
||||
print('\nTraining completed successfully')
|
||||
cleanup()
|
||||
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
|
||||
def launch(cfg: dict):
|
||||
assert torch.cuda.is_available()
|
||||
assert cfg.world_size > 0, 'Must train with at least 1 GPU.'
|
||||
assert cfg.task in {'mt30', 'mt80'}, 'Distributed training is only supported for multi-task experiments.'
|
||||
assert cfg.steps > 0, 'Must train for at least 1 step.'
|
||||
cfg = parse_cfg(cfg)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
torch.multiprocessing.spawn(
|
||||
train,
|
||||
args=(cfg,),
|
||||
nprocs=cfg.world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
||||
launch()
|
||||
|
||||
@@ -7,8 +7,9 @@ class Trainer:
|
||||
self.agent = agent
|
||||
self.buffer = buffer
|
||||
self.logger = logger
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
print('Architecture:', self.agent.model)
|
||||
if cfg.rank == 0:
|
||||
print('Architecture:', self.agent.model)
|
||||
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
|
||||
@@ -50,12 +50,21 @@ class OfflineTrainer(Trainer):
|
||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||
fps = sorted(glob(str(fp)))
|
||||
assert len(fps) > 0, f'No data found at {fp}'
|
||||
print(f'Found {len(fps)} files in {fp}')
|
||||
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Found {len(fps)} files in {fp}')
|
||||
|
||||
# Distribute data across processes
|
||||
assert len(fps) >= self.cfg.world_size, \
|
||||
f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
|
||||
fps = fps[self.cfg.rank::self.cfg.world_size]
|
||||
print(f'Process {self.cfg.rank} has {len(fps)} files')
|
||||
assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'
|
||||
|
||||
# Create buffer for sampling
|
||||
_cfg = deepcopy(self.cfg)
|
||||
_cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
|
||||
_cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
|
||||
_cfg.buffer_size //= self.cfg.world_size
|
||||
_cfg.steps = _cfg.buffer_size
|
||||
self.buffer = Buffer(_cfg)
|
||||
for fp in tqdm(fps, desc='Loading data'):
|
||||
@@ -65,10 +74,12 @@ class OfflineTrainer(Trainer):
|
||||
f'please double-check your config.'
|
||||
for i in range(len(td)):
|
||||
self.buffer.add(td[i])
|
||||
assert self.buffer.num_eps == self.buffer.capacity, \
|
||||
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.'
|
||||
if self.buffer.num_transitions > self.buffer.capacity:
|
||||
print(f'Buffer has {self.buffer.num_transitions} transitions,' \
|
||||
f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')
|
||||
|
||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||
if self.cfg.rank == 0:
|
||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||
metrics = {}
|
||||
for i in range(self.cfg.steps):
|
||||
|
||||
@@ -76,7 +87,7 @@ class OfflineTrainer(Trainer):
|
||||
train_metrics = self.agent.update(self.buffer)
|
||||
|
||||
# Evaluate agent periodically
|
||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
||||
if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
|
||||
metrics = {
|
||||
'iteration': i,
|
||||
'total_time': time() - self._start_time,
|
||||
@@ -89,4 +100,5 @@ class OfflineTrainer(Trainer):
|
||||
self.logger.save_agent(self.agent, identifier=f'{i}')
|
||||
self.logger.log(metrics, 'pretrain')
|
||||
|
||||
self.logger.finish(self.agent)
|
||||
if self.cfg.rank == 0:
|
||||
self.logger.finish(self.agent)
|
||||
|
||||
Reference in New Issue
Block a user