Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bbc14ebab | ||
|
|
7992fa193e | ||
|
|
7ec6bc83a8 | ||
|
|
38b31a5d72 | ||
|
|
7942e9082b | ||
|
|
eece80123d | ||
|
|
38f853efc4 | ||
|
|
62be41ab58 | ||
|
|
c95b755655 | ||
|
|
81eb17068e | ||
|
|
add30b5a74 | ||
|
|
0a914570dc | ||
|
|
55bde9745f | ||
|
|
5ced6dfeb4 | ||
|
|
dddc226d25 | ||
|
|
ae4238946f | ||
|
|
a19f91c0b5 | ||
|
|
e452ca7539 | ||
|
|
db1865334e | ||
|
|
804f9b3949 | ||
|
|
66f8c21f58 | ||
|
|
9cac7c5775 | ||
|
|
df8a465c8e | ||
|
|
2e27fbb6f4 | ||
|
|
6117bc427d | ||
|
|
32fc2bdf93 | ||
|
|
3789fcd5b8 | ||
|
|
d51feb0e9f | ||
|
|
2dc668ecaf | ||
|
|
39be86fc52 | ||
|
|
f0af4b6b27 | ||
|
|
a2d9d0c8ff | ||
|
|
ab43880945 | ||
|
|
ff02f41e73 | ||
|
|
e86c343a67 | ||
|
|
cc62c4c9ce | ||
|
|
fabf01a5ec | ||
|
|
26c72119cd |
25
README.md
25
README.md
@@ -12,9 +12,9 @@ Official implementation of
|
|||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
**Announcement: training just got ~4.5x faster!**
|
**Announcement (Apr 2025): support for episodic tasks!**
|
||||||
|
|
||||||
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
|
We have added support for episodic RL (tasks with terminations) in the latest release. This functionality can be enabled with `episodic=true` but remains disabled by default to ensure reproducibility of results across releases.
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
@@ -36,19 +36,18 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
|
|||||||
We provide a `Dockerfile` for easy installation. You can build the docker image by running
|
We provide a `Dockerfile` for easy installation. You can build the docker image by running
|
||||||
|
|
||||||
```
|
```
|
||||||
cd docker && docker build . -t <user>/tdmpc2:1.0.0
|
cd docker && docker build . -t <user>/tdmpc2:1.0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments.
|
This Docker image contains all dependencies needed for running DMControl experiments. We also provide a pre-built Docker image [here](https://hub.docker.com/repository/docker/nicklashansen/tdmpc2/tags/1.0.1/sha256-b07d4e04d4b28ffd9a63ac18ec1541950e874bb51d276c7d09b36135f170dd93).
|
||||||
|
|
||||||
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command:
|
If you prefer to use `conda` rather than docker, start by running the following command:
|
||||||
|
|
||||||
```
|
```
|
||||||
conda env create -f docker/environment.yaml
|
conda env create -f docker/environment.yaml
|
||||||
pip install gym==0.21.0
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`.
|
The `docker/environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `docker/environment.yaml`.
|
||||||
|
|
||||||
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
|
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
|
||||||
|
|
||||||
@@ -62,19 +61,19 @@ which downloads assets to `./data`. You may move these assets to any location. T
|
|||||||
export MS2_ASSET_DIR=<path>/<to>/<data>
|
export MS2_ASSET_DIR=<path>/<to>/<data>
|
||||||
```
|
```
|
||||||
|
|
||||||
and restart your terminal. Meta-World additionally requires MuJoCo 2.1.0. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at [https://www.tdmpc2.com/files/mjkey.txt](https://www.tdmpc2.com/files/mjkey.txt). You can download the license by running
|
and restart your terminal. Note that Meta-World requires MuJoCo 2.1.0 and `gym==0.21.0`, a combination that is becoming increasingly difficult to install. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at [https://www.tdmpc2.com/files/mjkey.txt](https://www.tdmpc2.com/files/mjkey.txt). You can download the license by running
|
||||||
|
|
||||||
```
|
```
|
||||||
wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt
|
wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
See `docker/Dockerfile` for installation instructions if you do not already have MuJoCo 2.1.0 installed. MyoSuite requires `gym==0.13.0` which is incompatible with Meta-World and ManiSkill2. Install separately with `pip install myosuite` if desired. Depending on your existing system packages, you may need to install other dependencies. See `docker/Dockerfile` for a list of recommended system packages.
|
Depending on your existing system packages, you may need to install other dependencies. See `docker/Dockerfile` for a list of recommended system packages.
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
## Supported tasks
|
## Supported tasks
|
||||||
|
|
||||||
This codebase currently supports **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite**. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
|
This codebase provides support for all **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite** used in our paper. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite. See the table below for the expected name formatting for each task domain:
|
||||||
|
|
||||||
| domain | task
|
| domain | task
|
||||||
| --- | --- |
|
| --- | --- |
|
||||||
@@ -87,9 +86,9 @@ This codebase currently supports **104** continuous control tasks from **DMContr
|
|||||||
| myosuite | myo-key-turn
|
| myosuite | myo-key-turn
|
||||||
| myosuite | myo-key-turn-hard
|
| myosuite | myo-key-turn-hard
|
||||||
|
|
||||||
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively.
|
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation are specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. While you generally do not need to access the underlying task IDs or embeddings during training or evaluation of our multi-task models, the mapping from task name to task embedding used in our work can be found [here](https://github.com/nicklashansen/tdmpc2/blob/7ec6bc83a82a5188ca3faddc59aea83f430ab570/tdmpc2/common/__init__.py#L26). As of April 2025, our codebase also provides basic support for other MuJoCo/Box2D Gymnasium tasks; refer to the `envs` directory for a list of tasks. It should be relatively straightforward to add support for custom tasks by following the examples in `envs`.
|
||||||
|
|
||||||
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies.
|
**Note:** we also provide support for image observations in the DMControl tasks. Use argument `obs=rgb` if you wish to train visual policies.
|
||||||
|
|
||||||
|
|
||||||
## Example usage
|
## Example usage
|
||||||
@@ -121,8 +120,6 @@ $ python train.py task=walker-walk obs=rgb
|
|||||||
|
|
||||||
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
|
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
|
||||||
|
|
||||||
**As of Jan 7, 2024 the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results, but they appear to be similar based on our limited testing.
|
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|||||||
@@ -4,14 +4,14 @@
|
|||||||
# https://www.tdmpc2.com #
|
# https://www.tdmpc2.com #
|
||||||
# -------------------------------------- #
|
# -------------------------------------- #
|
||||||
# Build instructions: #
|
# Build instructions: #
|
||||||
# docker build . -t <user>/tdmpc2:1.0.0 #
|
# docker build . -t <user>/tdmpc2:1.0.1 #
|
||||||
# docker push <user>/tdmpc2:1.0.0 #
|
# docker push <user>/tdmpc2:1.0.1 #
|
||||||
# -------------------------------------- #
|
# -------------------------------------- #
|
||||||
# Run: #
|
# Run: #
|
||||||
# docker run -i \ #
|
# docker run -i \ #
|
||||||
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
|
# -v <path>/<to>/tdmpc2:/tdmpc2 \ #
|
||||||
# --gpus all \ #
|
# --gpus all \ #
|
||||||
# -t <user>/tdmpc2:1.0.0 \ #
|
# -t <user>/tdmpc2:1.0.1 \ #
|
||||||
# /bin/bash #
|
# /bin/bash #
|
||||||
##########################################
|
##########################################
|
||||||
|
|
||||||
@@ -42,34 +42,8 @@ RUN conda update conda && \
|
|||||||
conda init
|
conda init
|
||||||
SHELL ["/bin/bash", "-c"]
|
SHELL ["/bin/bash", "-c"]
|
||||||
RUN echo "cd /root" >> /root/.bashrc
|
RUN echo "cd /root" >> /root/.bashrc
|
||||||
|
# image does not include metaworld, maniskill, myosuite
|
||||||
# mujoco 2.1.0
|
# these can be installed separately; see environment.yaml for details
|
||||||
ENV MUJOCO_GL egl
|
|
||||||
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
|
|
||||||
RUN mkdir -p /root/.mujoco && \
|
|
||||||
wget https://www.tdmpc2.com/files/mjkey.txt && \
|
|
||||||
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
|
|
||||||
tar -xzf mujoco210-linux-x86_64.tar.gz && \
|
|
||||||
rm mujoco210-linux-x86_64.tar.gz && \
|
|
||||||
mv mujoco210 /root/.mujoco/mujoco210 && \
|
|
||||||
mv mjkey.txt /root/.mujoco/mjkey.txt && \
|
|
||||||
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
|
|
||||||
python -c "import mujoco_py"
|
|
||||||
|
|
||||||
# gym
|
|
||||||
RUN pip install gym==0.21.0
|
|
||||||
|
|
||||||
# metaworld
|
|
||||||
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
|
|
||||||
|
|
||||||
# maniskill2
|
|
||||||
ENV MS2_ASSET_DIR /root/data
|
|
||||||
RUN pip install mani-skill2==0.4.1 && \
|
|
||||||
cd /root && \
|
|
||||||
python -m mani_skill2.utils.download_asset all -y
|
|
||||||
|
|
||||||
# myosuite (conflicts with meta-world / mani-skill2)
|
|
||||||
# RUN pip install myosuite
|
|
||||||
|
|
||||||
# success!
|
# success!
|
||||||
RUN echo "Successfully built TD-MPC2 Docker image!"
|
RUN echo "Successfully built TD-MPC2 Docker image!"
|
||||||
|
|||||||
@@ -13,10 +13,9 @@ dependencies:
|
|||||||
- pytorch-cuda=12.4
|
- pytorch-cuda=12.4
|
||||||
- torchvision=0.15.2
|
- torchvision=0.15.2
|
||||||
- pip:
|
- pip:
|
||||||
- absl-py==2.1.0
|
- dm-control==1.0.16
|
||||||
- "cython<3"
|
|
||||||
- dm-control==1.0.8
|
|
||||||
- glfw==2.7.0
|
- glfw==2.7.0
|
||||||
|
- gymnasium==0.29.1
|
||||||
- ffmpeg==1.4
|
- ffmpeg==1.4
|
||||||
- imageio==2.34.1
|
- imageio==2.34.1
|
||||||
- imageio-ffmpeg==0.4.9
|
- imageio-ffmpeg==0.4.9
|
||||||
@@ -24,24 +23,25 @@ dependencies:
|
|||||||
- hydra-core==1.3.2
|
- hydra-core==1.3.2
|
||||||
- hydra-submitit-launcher==1.2.0
|
- hydra-submitit-launcher==1.2.0
|
||||||
- submitit==1.5.1
|
- submitit==1.5.1
|
||||||
- setuptools==65.5.0
|
|
||||||
- patchelf==0.17.2.1
|
|
||||||
- omegaconf==2.3.0
|
- omegaconf==2.3.0
|
||||||
- moviepy==1.0.3
|
- moviepy==1.0.3
|
||||||
- mujoco==2.3.1
|
- mujoco==3.1.2
|
||||||
- mujoco-py==2.1.2.14
|
|
||||||
- numpy==1.24.4
|
- numpy==1.24.4
|
||||||
- tensordict-nightly==2024.11.14
|
- tensordict-nightly==2025.1.1
|
||||||
- torchrl-nightly==2024.11.14
|
- torchrl-nightly==2025.1.1
|
||||||
- kornia==0.7.2
|
- kornia==0.7.2
|
||||||
- termcolor==2.4.0
|
- termcolor==2.4.0
|
||||||
- tqdm==4.66.4
|
- tqdm==4.66.4
|
||||||
- pandas==2.0.3
|
- pandas==2.0.3
|
||||||
- wandb==0.17.4
|
- wandb==0.17.4
|
||||||
- wheel==0.38.0
|
|
||||||
####################
|
####################
|
||||||
# Gym:
|
# Gym:
|
||||||
# (unmaintained but required for maniskill2/meta-world)
|
# (unmaintained but required for maniskill2/meta-world)
|
||||||
|
# - "cython<3"
|
||||||
|
# - wheel==0.38.0
|
||||||
|
# - setuptools==65.5.0
|
||||||
|
# - mujoco==2.3.1
|
||||||
|
# - mujoco-py==2.1.2.14
|
||||||
# - gym==0.21.0
|
# - gym==0.21.0
|
||||||
####################
|
####################
|
||||||
# ManiSkill2:
|
# ManiSkill2:
|
||||||
@@ -55,3 +55,7 @@ dependencies:
|
|||||||
# MyoSuite:
|
# MyoSuite:
|
||||||
# - myosuite
|
# - myosuite
|
||||||
####################
|
####################
|
||||||
|
# Classic MuJoCo/Box2d:
|
||||||
|
# - swig
|
||||||
|
# - gymnasium[box2d]
|
||||||
|
####################
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class Buffer():
|
|||||||
traj_key='episode',
|
traj_key='episode',
|
||||||
truncated_key=None,
|
truncated_key=None,
|
||||||
strict_length=True,
|
strict_length=True,
|
||||||
|
cache_values=cfg.multitask,
|
||||||
)
|
)
|
||||||
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
self._batch_size = cfg.batch_size * (cfg.horizon+1)
|
||||||
self._num_eps = 0
|
self._num_eps = 0
|
||||||
@@ -65,19 +66,20 @@ class Buffer():
|
|||||||
LazyTensorStorage(self._capacity, device=self._storage_device)
|
LazyTensorStorage(self._capacity, device=self._storage_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_batch(self, td):
|
def load(self, td):
|
||||||
"""
|
"""
|
||||||
Prepare a sampled batch for training (post-processing).
|
Load a batch of episodes into the buffer. This is useful for loading data from disk,
|
||||||
Expects `td` to be a TensorDict with batch size TxB.
|
and is more efficient than adding episodes one by one.
|
||||||
"""
|
"""
|
||||||
td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True)
|
num_new_eps = len(td)
|
||||||
obs = td.get('obs').contiguous()
|
episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
|
||||||
action = td.get('action')[1:].contiguous()
|
td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
|
||||||
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
if self._num_eps == 0:
|
||||||
task = td.get('task', None)
|
self._buffer = self._init(td[0])
|
||||||
if task is not None:
|
td = td.reshape(td.shape[0]*td.shape[1])
|
||||||
task = task[0].contiguous()
|
self._buffer.extend(td)
|
||||||
return obs, action, reward, task
|
self._num_eps += num_new_eps
|
||||||
|
return self._num_eps
|
||||||
|
|
||||||
def add(self, td):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer."""
|
"""Add an episode to the buffer."""
|
||||||
@@ -88,6 +90,25 @@ class Buffer():
|
|||||||
self._num_eps += 1
|
self._num_eps += 1
|
||||||
return self._num_eps
|
return self._num_eps
|
||||||
|
|
||||||
|
def _prepare_batch(self, td):
|
||||||
|
"""
|
||||||
|
Prepare a sampled batch for training (post-processing).
|
||||||
|
Expects `td` to be a TensorDict with batch size TxB.
|
||||||
|
"""
|
||||||
|
td = td.select("obs", "action", "reward", "terminated", "task", strict=False).to(self._device, non_blocking=True)
|
||||||
|
obs = td.get('obs').contiguous()
|
||||||
|
action = td.get('action')[1:].contiguous()
|
||||||
|
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
||||||
|
terminated = td.get('terminated', None)
|
||||||
|
if terminated is not None:
|
||||||
|
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
|
||||||
|
else:
|
||||||
|
terminated = torch.zeros_like(reward)
|
||||||
|
task = td.get('task', None)
|
||||||
|
if task is not None:
|
||||||
|
task = task[0].contiguous()
|
||||||
|
return obs, action, reward, terminated, task
|
||||||
|
|
||||||
def sample(self):
|
def sample(self):
|
||||||
"""Sample a batch of subsequences from the buffer."""
|
"""Sample a batch of subsequences from the buffer."""
|
||||||
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
td = self._buffer.sample().view(-1, self.cfg.horizon+1).permute(1, 0)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||||||
from tensordict import from_modules
|
from tensordict import from_modules
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.Module):
|
class Ensemble(nn.Module):
|
||||||
"""
|
"""
|
||||||
Vectorized ensemble of modules.
|
Vectorized ensemble of modules.
|
||||||
@@ -15,7 +16,11 @@ class Ensemble(nn.Module):
|
|||||||
self.params = from_modules(*modules, as_module=True)
|
self.params = from_modules(*modules, as_module=True)
|
||||||
with self.params[0].data.to("meta").to_module(modules[0]):
|
with self.params[0].data.to("meta").to_module(modules[0]):
|
||||||
self.module = deepcopy(modules[0])
|
self.module = deepcopy(modules[0])
|
||||||
self._repr = str(modules)
|
self._repr = str(modules[0])
|
||||||
|
self._n = len(modules)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._n
|
||||||
|
|
||||||
def _call(self, params, *args, **kwargs):
|
def _call(self, params, *args, **kwargs):
|
||||||
with params.to_module(self.module):
|
with params.to_module(self.module):
|
||||||
@@ -25,7 +30,7 @@ class Ensemble(nn.Module):
|
|||||||
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return 'Vectorized ' + self._repr
|
return f'Vectorized {len(self)}x ' + self._repr
|
||||||
|
|
||||||
|
|
||||||
class ShiftAug(nn.Module):
|
class ShiftAug(nn.Module):
|
||||||
@@ -157,3 +162,60 @@ def enc(cfg, out={}):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
|
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
|
||||||
return nn.ModuleDict(out)
|
return nn.ModuleDict(out)
|
||||||
|
|
||||||
|
|
||||||
|
def api_model_conversion(target_state_dict, source_state_dict):
|
||||||
|
"""
|
||||||
|
Converts a checkpoint from our old API to the new torch.compile compatible API.
|
||||||
|
"""
|
||||||
|
# check whether checkpoint is already in the new format
|
||||||
|
if "_detach_Qs_params.0.weight" in source_state_dict:
|
||||||
|
return source_state_dict
|
||||||
|
|
||||||
|
name_map = ['weight', 'bias', 'ln.weight', 'ln.bias']
|
||||||
|
new_state_dict = dict()
|
||||||
|
|
||||||
|
# rename keys
|
||||||
|
for key, val in list(source_state_dict.items()):
|
||||||
|
if key.startswith('_Qs.'):
|
||||||
|
num = key[len('_Qs.params.'):]
|
||||||
|
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||||
|
new_total_key = "_Qs.params." + new_key
|
||||||
|
del source_state_dict[key]
|
||||||
|
new_state_dict[new_total_key] = val
|
||||||
|
new_total_key = "_detach_Qs_params." + new_key
|
||||||
|
new_state_dict[new_total_key] = val
|
||||||
|
elif key.startswith('_target_Qs.'):
|
||||||
|
num = key[len('_target_Qs.params.'):]
|
||||||
|
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||||
|
new_total_key = "_target_Qs_params." + new_key
|
||||||
|
del source_state_dict[key]
|
||||||
|
new_state_dict[new_total_key] = val
|
||||||
|
|
||||||
|
# add batch_size and device from target_state_dict to new_state_dict
|
||||||
|
for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_'):
|
||||||
|
for key in ('__batch_size', '__device'):
|
||||||
|
new_key = prefix + 'params.' + key
|
||||||
|
new_state_dict[new_key] = target_state_dict[new_key]
|
||||||
|
|
||||||
|
# check that every key in new_state_dict is in target_state_dict
|
||||||
|
for key in new_state_dict.keys():
|
||||||
|
assert key in target_state_dict, f"key {key} not in target_state_dict"
|
||||||
|
# check that all Qs keys in target_state_dict are in new_state_dict
|
||||||
|
for key in target_state_dict.keys():
|
||||||
|
if 'Qs' in key:
|
||||||
|
assert key in new_state_dict, f"key {key} not in new_state_dict"
|
||||||
|
# check that source_state_dict contains no Qs keys
|
||||||
|
for key in source_state_dict.keys():
|
||||||
|
assert 'Qs' not in key, f"key {key} contains 'Qs'"
|
||||||
|
|
||||||
|
# copy log_std_min and log_std_dif from target_state_dict to new_state_dict
|
||||||
|
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
|
||||||
|
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
|
||||||
|
if '_action_masks' in target_state_dict:
|
||||||
|
new_state_dict['_action_masks'] = target_state_dict['_action_masks']
|
||||||
|
|
||||||
|
# copy new_state_dict to source_state_dict
|
||||||
|
source_state_dict.update(new_state_dict)
|
||||||
|
|
||||||
|
return source_state_dict
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
|
|||||||
("step", "I", "int"),
|
("step", "I", "int"),
|
||||||
("episode_reward", "R", "float"),
|
("episode_reward", "R", "float"),
|
||||||
("episode_success", "S", "float"),
|
("episode_success", "S", "float"),
|
||||||
("total_time", "T", "time"),
|
("elapsed_time", "T", "time"),
|
||||||
]
|
]
|
||||||
|
|
||||||
CAT_TO_COLOR = {
|
CAT_TO_COLOR = {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
|
||||||
def soft_ce(pred, target, cfg):
|
def soft_ce(pred, target, cfg):
|
||||||
@@ -13,35 +14,32 @@ def log_std(x, low, dif):
|
|||||||
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||||
|
|
||||||
|
|
||||||
def _gaussian_residual(eps, log_std):
|
def gaussian_logprob(eps, log_std):
|
||||||
return -0.5 * eps.pow(2) - log_std
|
|
||||||
|
|
||||||
|
|
||||||
def _gaussian_logprob(residual):
|
|
||||||
log2pi = 1.8378770351409912
|
|
||||||
return residual - 0.5 * log2pi
|
|
||||||
|
|
||||||
|
|
||||||
def gaussian_logprob(eps, log_std, size=None):
|
|
||||||
"""Compute Gaussian log probability."""
|
"""Compute Gaussian log probability."""
|
||||||
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
residual = -0.5 * eps.pow(2) - log_std
|
||||||
if size is None:
|
log_prob = residual - 0.9189385175704956
|
||||||
size = eps.shape[-1]
|
return log_prob.sum(-1, keepdim=True)
|
||||||
return _gaussian_logprob(residual) * size
|
|
||||||
|
|
||||||
|
|
||||||
def _squash(pi):
|
|
||||||
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
|
||||||
|
|
||||||
|
|
||||||
def squash(mu, pi, log_pi):
|
def squash(mu, pi, log_pi):
|
||||||
"""Apply squashing function."""
|
"""Apply squashing function."""
|
||||||
mu = torch.tanh(mu)
|
mu = torch.tanh(mu)
|
||||||
pi = torch.tanh(pi)
|
pi = torch.tanh(pi)
|
||||||
log_pi -= _squash(pi).sum(-1, keepdim=True)
|
squashed_pi = torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
||||||
|
log_pi = log_pi - squashed_pi.sum(-1, keepdim=True)
|
||||||
return mu, pi, log_pi
|
return mu, pi, log_pi
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_one_hot(x, num_classes):
|
||||||
|
"""
|
||||||
|
Converts an integer tensor to a one-hot tensor.
|
||||||
|
Supports batched inputs.
|
||||||
|
"""
|
||||||
|
one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
|
||||||
|
one_hot.scatter_(-1, x.unsqueeze(-1), 1)
|
||||||
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def symlog(x):
|
def symlog(x):
|
||||||
"""
|
"""
|
||||||
Symmetric logarithmic function.
|
Symmetric logarithmic function.
|
||||||
@@ -87,11 +85,26 @@ def two_hot_inv(x, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
||||||
|
"""Sample from the Gumbel-Softmax distribution."""
|
||||||
logits = p.log()
|
logits = p.log()
|
||||||
# Generate Gumbel noise
|
|
||||||
gumbels = (
|
gumbels = (
|
||||||
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
||||||
) # ~Gumbel(0,1)
|
) # ~Gumbel(0,1)
|
||||||
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
|
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
|
||||||
y_soft = gumbels.softmax(dim)
|
y_soft = gumbels.softmax(dim)
|
||||||
return y_soft.argmax(-1)
|
return y_soft.argmax(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def termination_statistics(pred, target, eps=1e-9):
|
||||||
|
"""Compute episode termination statistics."""
|
||||||
|
pred = pred.squeeze(-1)
|
||||||
|
target = target.squeeze(-1)
|
||||||
|
rate = target.sum() / len(target)
|
||||||
|
tp = ((pred > 0.5) & (target == 1)).sum()
|
||||||
|
fn = ((pred <= 0.5) & (target == 1)).sum()
|
||||||
|
fp = ((pred > 0.5) & (target == 0)).sum()
|
||||||
|
recall = tp / (tp + fn + eps)
|
||||||
|
precision = tp / (tp + fp + eps)
|
||||||
|
f1 = 2 * (precision * recall) / (precision + recall + eps)
|
||||||
|
return TensorDict({'termination_rate': rate,
|
||||||
|
'termination_f1': f1})
|
||||||
|
|||||||
@@ -77,9 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
cfg.task_dim = 0
|
cfg.task_dim = 0
|
||||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||||
|
|
||||||
# Check torch.compile compatibility
|
|
||||||
if cfg.get('compile', False):
|
|
||||||
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
|
|
||||||
assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.'
|
|
||||||
|
|
||||||
return cfg_to_dataclass(cfg)
|
return cfg_to_dataclass(cfg)
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn import Buffer
|
from torch.nn import Buffer
|
||||||
|
|
||||||
|
|
||||||
class RunningScale(torch.nn.Module):
|
class RunningScale(torch.nn.Module):
|
||||||
"""Running trimmed scale estimator."""
|
"""Running trimmed scale estimator."""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda:0')))
|
||||||
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
|
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda:0')))
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return dict(value=self.value, percentiles=self._percentiles)
|
return dict(value=self.value, percentiles=self._percentiles)
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
|
from tensordict import TensorDict
|
||||||
from tensordict.nn import TensorDictParams
|
from tensordict.nn import TensorDictParams
|
||||||
|
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
TD-MPC2 implicit world model architecture.
|
TD-MPC2 implicit world model architecture.
|
||||||
@@ -23,6 +25,7 @@ class WorldModel(nn.Module):
|
|||||||
self._encoder = layers.enc(cfg)
|
self._encoder = layers.enc(cfg)
|
||||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||||
|
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) if cfg.episodic else None
|
||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
@@ -43,13 +46,18 @@ class WorldModel(nn.Module):
|
|||||||
self._target_Qs = deepcopy(self._Qs)
|
self._target_Qs = deepcopy(self._Qs)
|
||||||
|
|
||||||
# Assign params to modules
|
# Assign params to modules
|
||||||
self._detach_Qs.params = self._detach_Qs_params
|
# We do this strange assignment to avoid having duplicated tensors in the state-dict -- working on a better API for this
|
||||||
self._target_Qs.params = self._target_Qs_params
|
delattr(self._detach_Qs, "params")
|
||||||
|
self._detach_Qs.__dict__["params"] = self._detach_Qs_params
|
||||||
|
delattr(self._target_Qs, "params")
|
||||||
|
self._target_Qs.__dict__["params"] = self._target_Qs_params
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr = 'TD-MPC2 World Model\n'
|
repr = 'TD-MPC2 World Model\n'
|
||||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
|
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
||||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
|
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
||||||
|
if m == self._termination and not self.cfg.episodic:
|
||||||
|
continue
|
||||||
repr += f"{modules[i]}: {m}\n"
|
repr += f"{modules[i]}: {m}\n"
|
||||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||||
return repr
|
return repr
|
||||||
@@ -120,6 +128,18 @@ class WorldModel(nn.Module):
|
|||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._reward(z)
|
return self._reward(z)
|
||||||
|
|
||||||
|
def termination(self, z, task, unnormalized=False):
|
||||||
|
"""
|
||||||
|
Predicts termination signal.
|
||||||
|
"""
|
||||||
|
assert task is None
|
||||||
|
if self.cfg.multitask:
|
||||||
|
z = self.task_emb(z, task)
|
||||||
|
if unnormalized:
|
||||||
|
return self._termination(z)
|
||||||
|
return torch.sigmoid(self._termination(z))
|
||||||
|
|
||||||
|
|
||||||
def pi(self, z, task):
|
def pi(self, z, task):
|
||||||
"""
|
"""
|
||||||
@@ -131,23 +151,37 @@ class WorldModel(nn.Module):
|
|||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
|
|
||||||
# Gaussian policy prior
|
# Gaussian policy prior
|
||||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
mean, log_std = self._pi(z).chunk(2, dim=-1)
|
||||||
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
||||||
eps = torch.randn_like(mu)
|
eps = torch.randn_like(mean)
|
||||||
|
|
||||||
if self.cfg.multitask: # Mask out unused action dimensions
|
if self.cfg.multitask: # Mask out unused action dimensions
|
||||||
mu = mu * self._action_masks[task]
|
mean = mean * self._action_masks[task]
|
||||||
log_std = log_std * self._action_masks[task]
|
log_std = log_std * self._action_masks[task]
|
||||||
eps = eps * self._action_masks[task]
|
eps = eps * self._action_masks[task]
|
||||||
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
|
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
|
||||||
else: # No masking
|
else: # No masking
|
||||||
action_dims = None
|
action_dims = None
|
||||||
|
|
||||||
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
|
log_prob = math.gaussian_logprob(eps, log_std)
|
||||||
pi = mu + eps * log_std.exp()
|
|
||||||
mu, pi, log_pi = math.squash(mu, pi, log_pi)
|
|
||||||
|
|
||||||
return mu, pi, log_pi, log_std
|
# Scale log probability by action dimensions
|
||||||
|
size = eps.shape[-1] if action_dims is None else action_dims
|
||||||
|
scaled_log_prob = log_prob * size
|
||||||
|
|
||||||
|
# Reparameterization trick
|
||||||
|
action = mean + eps * log_std.exp()
|
||||||
|
mean, action, log_prob = math.squash(mean, action, log_prob)
|
||||||
|
|
||||||
|
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
|
||||||
|
info = TensorDict({
|
||||||
|
"mean": mean,
|
||||||
|
"log_std": log_std,
|
||||||
|
"action_prob": 1.,
|
||||||
|
"entropy": -log_prob,
|
||||||
|
"scaled_entropy": -log_prob * entropy_scale,
|
||||||
|
})
|
||||||
|
return action, info
|
||||||
|
|
||||||
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ defaults:
|
|||||||
# environment
|
# environment
|
||||||
task: dog-run
|
task: dog-run
|
||||||
obs: state
|
obs: state
|
||||||
|
episodic: false
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
@@ -15,6 +16,7 @@ steps: 10_000_000
|
|||||||
batch_size: 256
|
batch_size: 256
|
||||||
reward_coef: 0.1
|
reward_coef: 0.1
|
||||||
value_coef: 0.1
|
value_coef: 0.1
|
||||||
|
termination_coef: 1
|
||||||
consistency_coef: 20
|
consistency_coef: 20
|
||||||
rho: 0.5
|
rho: 0.5
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
@@ -69,6 +71,7 @@ enable_wandb: true
|
|||||||
save_csv: true
|
save_csv: true
|
||||||
|
|
||||||
# misc
|
# misc
|
||||||
|
compile: true
|
||||||
save_video: true
|
save_video: true
|
||||||
save_agent: true
|
save_agent: true
|
||||||
seed: 1
|
seed: 1
|
||||||
@@ -86,6 +89,3 @@ action_dims: ???
|
|||||||
episode_lengths: ???
|
episode_lengths: ???
|
||||||
seed_steps: ???
|
seed_steps: ???
|
||||||
bin_size: ???
|
bin_size: ???
|
||||||
|
|
||||||
# speedups
|
|
||||||
compile: False
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
from envs.wrappers.pixels import PixelWrapper
|
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
|
|
||||||
def missing_dependencies(task):
|
def missing_dependencies(task):
|
||||||
@@ -26,6 +25,10 @@ try:
|
|||||||
from envs.myosuite import make_env as make_myosuite_env
|
from envs.myosuite import make_env as make_myosuite_env
|
||||||
except:
|
except:
|
||||||
make_myosuite_env = missing_dependencies
|
make_myosuite_env = missing_dependencies
|
||||||
|
try:
|
||||||
|
from envs.mujoco import make_env as make_mujoco_env
|
||||||
|
except:
|
||||||
|
make_mujoco_env = missing_dependencies
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
@@ -62,7 +65,7 @@ def make_env(cfg):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
env = None
|
env = None
|
||||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
|
||||||
try:
|
try:
|
||||||
env = fn(cfg)
|
env = fn(cfg)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -70,8 +73,6 @@ def make_env(cfg):
|
|||||||
if env is None:
|
if env is None:
|
||||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||||
env = TensorWrapper(env)
|
env = TensorWrapper(env)
|
||||||
if cfg.get('obs', 'state') == 'rgb':
|
|
||||||
env = PixelWrapper(cfg, env)
|
|
||||||
try: # Dict
|
try: # Dict
|
||||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||||
except: # Box
|
except: # Box
|
||||||
|
|||||||
@@ -1,181 +1,92 @@
|
|||||||
from collections import deque, defaultdict
|
from collections import defaultdict, deque
|
||||||
from typing import Any, NamedTuple
|
|
||||||
import dm_env
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
|
||||||
from dm_control import suite
|
from dm_control import suite
|
||||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||||
from dm_control.suite.wrappers import action_scale
|
from dm_control.suite.wrappers import action_scale
|
||||||
from dm_env import StepType, specs
|
|
||||||
import gym
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
class ExtendedTimeStep(NamedTuple):
|
def get_obs_shape(env):
|
||||||
step_type: Any
|
obs_shp = []
|
||||||
reward: Any
|
for v in env.observation_spec().values():
|
||||||
discount: Any
|
try:
|
||||||
observation: Any
|
shp = np.prod(v.shape)
|
||||||
action: Any
|
except:
|
||||||
|
shp = 1
|
||||||
def first(self):
|
obs_shp.append(shp)
|
||||||
return self.step_type == StepType.FIRST
|
return (int(np.sum(obs_shp)),)
|
||||||
|
|
||||||
def mid(self):
|
|
||||||
return self.step_type == StepType.MID
|
|
||||||
|
|
||||||
def last(self):
|
|
||||||
return self.step_type == StepType.LAST
|
|
||||||
|
|
||||||
|
|
||||||
class ActionRepeatWrapper(dm_env.Environment):
|
class DMControlWrapper:
|
||||||
def __init__(self, env, num_repeats):
|
def __init__(self, env, domain):
|
||||||
self._env = env
|
|
||||||
self._num_repeats = num_repeats
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
reward = 0.0
|
|
||||||
discount = 1.0
|
|
||||||
for i in range(self._num_repeats):
|
|
||||||
time_step = self._env.step(action)
|
|
||||||
reward += (time_step.reward or 0.0) * discount
|
|
||||||
discount *= time_step.discount
|
|
||||||
if time_step.last():
|
|
||||||
break
|
|
||||||
|
|
||||||
return time_step._replace(reward=reward, discount=discount)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._env.action_spec()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self._env.reset()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionDTypeWrapper(dm_env.Environment):
|
|
||||||
def __init__(self, env, dtype):
|
|
||||||
self._env = env
|
|
||||||
wrapped_action_spec = env.action_spec()
|
|
||||||
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
|
|
||||||
dtype,
|
|
||||||
wrapped_action_spec.minimum,
|
|
||||||
wrapped_action_spec.maximum,
|
|
||||||
'action')
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
action = action.astype(self._env.action_spec().dtype)
|
|
||||||
return self._env.step(action)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._action_spec
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
return self._env.reset()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class ExtendedTimeStepWrapper(dm_env.Environment):
|
|
||||||
def __init__(self, env):
|
|
||||||
self._env = env
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
time_step = self._env.reset()
|
|
||||||
return self._augment_time_step(time_step)
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
time_step = self._env.step(action)
|
|
||||||
return self._augment_time_step(time_step, action)
|
|
||||||
|
|
||||||
def _augment_time_step(self, time_step, action=None):
|
|
||||||
if action is None:
|
|
||||||
action_spec = self.action_spec()
|
|
||||||
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
|
|
||||||
return ExtendedTimeStep(observation=time_step.observation,
|
|
||||||
step_type=time_step.step_type,
|
|
||||||
action=action,
|
|
||||||
reward=time_step.reward or 0.0,
|
|
||||||
discount=time_step.discount or 1.0)
|
|
||||||
|
|
||||||
def observation_spec(self):
|
|
||||||
return self._env.observation_spec()
|
|
||||||
|
|
||||||
def action_spec(self):
|
|
||||||
return self._env.action_spec()
|
|
||||||
|
|
||||||
def __getattr__(self, name):
|
|
||||||
return getattr(self._env, name)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeStepToGymWrapper:
|
|
||||||
def __init__(self, env, domain, task):
|
|
||||||
obs_shp = []
|
|
||||||
for v in env.observation_spec().values():
|
|
||||||
try:
|
|
||||||
shp = np.prod(v.shape)
|
|
||||||
except:
|
|
||||||
shp = 1
|
|
||||||
obs_shp.append(shp)
|
|
||||||
obs_shp = (int(np.sum(obs_shp)),)
|
|
||||||
act_shp = env.action_spec().shape
|
|
||||||
self.observation_space = gym.spaces.Box(
|
|
||||||
low=np.full(
|
|
||||||
obs_shp,
|
|
||||||
-np.inf,
|
|
||||||
dtype=np.float32),
|
|
||||||
high=np.full(
|
|
||||||
obs_shp,
|
|
||||||
np.inf,
|
|
||||||
dtype=np.float32),
|
|
||||||
dtype=np.float32,
|
|
||||||
)
|
|
||||||
self.action_space = gym.spaces.Box(
|
|
||||||
low=np.full(act_shp, env.action_spec().minimum),
|
|
||||||
high=np.full(act_shp, env.action_spec().maximum),
|
|
||||||
dtype=env.action_spec().dtype)
|
|
||||||
self.env = env
|
self.env = env
|
||||||
self.domain = domain
|
self.camera_id = 2 if domain == 'quadruped' else 0
|
||||||
self.task = task
|
obs_shape = get_obs_shape(env)
|
||||||
self.max_episode_steps = 500
|
action_shape = env.action_spec().shape
|
||||||
self.t = 0
|
self.observation_space = gym.spaces.Box(
|
||||||
|
low=np.full(obs_shape, -np.inf, dtype=np.float32),
|
||||||
|
high=np.full(obs_shape, np.inf, dtype=np.float32),
|
||||||
|
dtype=np.float32)
|
||||||
|
self.action_space = gym.spaces.Box(
|
||||||
|
low=np.full(action_shape, env.action_spec().minimum),
|
||||||
|
high=np.full(action_shape, env.action_spec().maximum),
|
||||||
|
dtype=env.action_spec().dtype)
|
||||||
|
self.action_spec_dtype = env.action_spec().dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
return self.env
|
return self.env
|
||||||
|
|
||||||
@property
|
|
||||||
def reward_range(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def metadata(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _obs_to_array(self, obs):
|
def _obs_to_array(self, obs):
|
||||||
return np.concatenate([v.flatten() for v in obs.values()])
|
return torch.from_numpy(
|
||||||
|
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return self._obs_to_array(self.env.reset().observation)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
reward = 0
|
||||||
|
action = action.astype(self.action_spec_dtype)
|
||||||
|
for _ in range(2):
|
||||||
|
step = self.env.step(action)
|
||||||
|
reward += step.reward
|
||||||
|
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
|
||||||
|
|
||||||
|
def render(self, width=384, height=384, camera_id=None):
|
||||||
|
return self.env.physics.render(height, width, camera_id or self.camera_id)
|
||||||
|
|
||||||
|
|
||||||
|
class Pixels(gym.Wrapper):
|
||||||
|
def __init__(self, env, cfg, num_frames=3, size=64):
|
||||||
|
super().__init__(env)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.env = env
|
||||||
|
self.observation_space = gym.spaces.Box(
|
||||||
|
low=0, high=255, shape=(num_frames*3, size, size), dtype=np.uint8)
|
||||||
|
self._frames = deque([], maxlen=num_frames)
|
||||||
|
self._size = size
|
||||||
|
|
||||||
|
def _get_obs(self, is_reset=False):
|
||||||
|
frame = self.env.render(width=self._size, height=self._size).transpose(2, 0, 1)
|
||||||
|
num_frames = self._frames.maxlen if is_reset else 1
|
||||||
|
for _ in range(num_frames):
|
||||||
|
self._frames.append(frame)
|
||||||
|
return torch.from_numpy(np.concatenate(self._frames))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.t = 0
|
self.env.reset()
|
||||||
return self._obs_to_array(self.env.reset().observation)
|
return self._get_obs(is_reset=True)
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
self.t += 1
|
|
||||||
time_step = self.env.step(action)
|
|
||||||
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float)
|
|
||||||
|
|
||||||
def render(self, mode='rgb_array', width=384, height=384, camera_id=0):
|
def step(self, action):
|
||||||
camera_id = dict(quadruped=2).get(self.domain, camera_id)
|
_, reward, done, info = self.env.step(action)
|
||||||
return self.env.physics.render(height, width, camera_id)
|
return self._get_obs(), reward, done, info
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg):
|
def make_env(cfg):
|
||||||
@@ -192,9 +103,9 @@ def make_env(cfg):
|
|||||||
task,
|
task,
|
||||||
task_kwargs={'random': cfg.seed},
|
task_kwargs={'random': cfg.seed},
|
||||||
visualize_reward=False)
|
visualize_reward=False)
|
||||||
env = ActionDTypeWrapper(env, np.float32)
|
|
||||||
env = ActionRepeatWrapper(env, 2)
|
|
||||||
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
|
||||||
env = ExtendedTimeStepWrapper(env)
|
env = DMControlWrapper(env, domain)
|
||||||
env = TimeStepToGymWrapper(env, domain, task)
|
if cfg.obs == 'rgb':
|
||||||
|
env = Pixels(env, cfg)
|
||||||
|
env = Timeout(env, max_episode_steps=500)
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
import mani_skill2.envs
|
import mani_skill2.envs
|
||||||
|
|
||||||
@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0
|
reward = 0
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
obs, r, _, info = self.env.step(action)
|
obs, r, done, info = self.env.step(action)
|
||||||
reward += r
|
reward += r
|
||||||
return obs, reward, False, info
|
info['terminated'] = done
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
@@ -74,6 +77,6 @@ def make_env(cfg):
|
|||||||
render_camera_cfgs=dict(width=384, height=384),
|
render_camera_cfgs=dict(width=384, height=384),
|
||||||
)
|
)
|
||||||
env = ManiSkillWrapper(env, cfg)
|
env = ManiSkillWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gym
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
|
||||||
|
|
||||||
@@ -47,6 +47,6 @@ def make_env(cfg):
|
|||||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||||
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
|
||||||
env = MetaWorldWrapper(env, cfg)
|
env = MetaWorldWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
59
tdmpc2/envs/mujoco.py
Normal file
59
tdmpc2/envs/mujoco.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import numpy as np
|
||||||
|
import gymnasium as gym
|
||||||
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
|
MUJOCO_TASKS = {
|
||||||
|
'mujoco-walker': 'Walker2d-v4',
|
||||||
|
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
||||||
|
'bipedal-walker': 'BipedalWalker-v3',
|
||||||
|
'lunarlander-continuous': 'LunarLander-v2',
|
||||||
|
}
|
||||||
|
|
||||||
|
class MuJoCoWrapper(gym.Wrapper):
|
||||||
|
def __init__(self, env, cfg):
|
||||||
|
super().__init__(env)
|
||||||
|
self.env = env
|
||||||
|
self.cfg = cfg
|
||||||
|
self._cumulative_reward = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._cumulative_reward = 0
|
||||||
|
return self.env.reset()[0]
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action.copy())
|
||||||
|
self._cumulative_reward += reward
|
||||||
|
done = terminated or truncated
|
||||||
|
info['terminated'] = terminated
|
||||||
|
if self.cfg.task == 'lunarlander-continuous':
|
||||||
|
info['success'] = self._cumulative_reward > 200
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unwrapped(self):
|
||||||
|
return self.env.unwrapped
|
||||||
|
|
||||||
|
def render(self, **kwargs):
|
||||||
|
return self.env.render(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(cfg):
|
||||||
|
"""
|
||||||
|
Make classic/MuJoCo environment.
|
||||||
|
"""
|
||||||
|
if not cfg.task in MUJOCO_TASKS:
|
||||||
|
raise ValueError('Unknown task:', cfg.task)
|
||||||
|
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||||
|
if cfg.task == 'lunarlander-continuous':
|
||||||
|
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
||||||
|
else:
|
||||||
|
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||||
|
env = MuJoCoWrapper(env, cfg)
|
||||||
|
env = Timeout(env, max_episode_steps={
|
||||||
|
'lunarlander-continuous': 500,
|
||||||
|
'bipedal-walker': 1600,
|
||||||
|
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
|
||||||
|
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
||||||
|
cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
|
||||||
|
return env
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import gym
|
import gymnasium as gym
|
||||||
from envs.wrappers.time_limit import TimeLimit
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
MYOSUITE_TASKS = {
|
MYOSUITE_TASKS = {
|
||||||
@@ -53,6 +53,6 @@ def make_env(cfg):
|
|||||||
from myosuite.utils import gym as gym_utils
|
from myosuite.utils import gym as gym_utils
|
||||||
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
|
||||||
env = MyoSuiteWrapper(env, cfg)
|
env = MyoSuiteWrapper(env, cfg)
|
||||||
env = TimeLimit(env, max_episode_steps=100)
|
env = Timeout(env, max_episode_steps=100)
|
||||||
env.max_episode_steps = env._max_episode_steps
|
env.max_episode_steps = env._max_episode_steps
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,38 +0,0 @@
|
|||||||
from collections import deque
|
|
||||||
|
|
||||||
import gym
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class PixelWrapper(gym.Wrapper):
|
|
||||||
"""
|
|
||||||
Wrapper for pixel observations. Compatible with DMControl environments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cfg, env, num_frames=3, render_size=64):
|
|
||||||
super().__init__(env)
|
|
||||||
self.cfg = cfg
|
|
||||||
self.env = env
|
|
||||||
self.observation_space = gym.spaces.Box(
|
|
||||||
low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8
|
|
||||||
)
|
|
||||||
self._frames = deque([], maxlen=num_frames)
|
|
||||||
self._render_size = render_size
|
|
||||||
|
|
||||||
def _get_obs(self):
|
|
||||||
frame = self.env.render(
|
|
||||||
mode='rgb_array', width=self._render_size, height=self._render_size
|
|
||||||
).transpose(2, 0, 1)
|
|
||||||
self._frames.append(frame)
|
|
||||||
return torch.from_numpy(np.concatenate(self._frames))
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.env.reset()
|
|
||||||
for _ in range(self._frames.maxlen):
|
|
||||||
obs = self._get_obs()
|
|
||||||
return obs
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
_, reward, done, info = self.env.step(action)
|
|
||||||
return self._get_obs(), reward, done, info
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -17,9 +17,10 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
return torch.from_numpy(self.action_space.sample().astype(np.float32))
|
||||||
|
|
||||||
def _try_f32_tensor(self, x):
|
def _try_f32_tensor(self, x):
|
||||||
x = torch.from_numpy(x)
|
if isinstance(x, np.ndarray):
|
||||||
if x.dtype == torch.float64:
|
x = torch.from_numpy(x)
|
||||||
x = x.float()
|
if x.dtype == torch.float64:
|
||||||
|
x = x.float()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _obs_to_tensor(self, obs):
|
def _obs_to_tensor(self, obs):
|
||||||
@@ -37,4 +38,5 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
obs, reward, done, info = self.env.step(action.numpy())
|
obs, reward, done, info = self.env.step(action.numpy())
|
||||||
info = defaultdict(float, info)
|
info = defaultdict(float, info)
|
||||||
info['success'] = float(info['success'])
|
info['success'] = float(info['success'])
|
||||||
|
info['terminated'] = torch.tensor(float(info['terminated']))
|
||||||
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
"""
|
|
||||||
Wrapper for limiting the time steps of an environment.
|
|
||||||
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import gym
|
|
||||||
|
|
||||||
|
|
||||||
class TimeLimit(gym.Wrapper):
|
|
||||||
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
|
|
||||||
|
|
||||||
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
|
|
||||||
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
|
|
||||||
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
|
|
||||||
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
|
|
||||||
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
>>> from gym.envs.classic_control import CartPoleEnv
|
|
||||||
>>> from gym.wrappers import TimeLimit
|
|
||||||
>>> env = CartPoleEnv()
|
|
||||||
>>> env = TimeLimit(env, max_episode_steps=1000)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
|
|
||||||
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
env: The environment to apply the wrapper
|
|
||||||
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
|
|
||||||
"""
|
|
||||||
super().__init__(env)
|
|
||||||
if max_episode_steps is None and self.env.spec is not None:
|
|
||||||
max_episode_steps = env.spec.max_episode_steps
|
|
||||||
if self.env.spec is not None:
|
|
||||||
self.env.spec.max_episode_steps = max_episode_steps
|
|
||||||
self._max_episode_steps = max_episode_steps
|
|
||||||
self._elapsed_steps = None
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action: The environment step action
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
|
|
||||||
when truncated (the number of steps elapsed >= max episode steps) or
|
|
||||||
"TimeLimit.truncated"=False if the environment terminated
|
|
||||||
"""
|
|
||||||
observation, reward, done, info = self.env.step(action)
|
|
||||||
self._elapsed_steps += 1
|
|
||||||
if self._elapsed_steps >= self._max_episode_steps:
|
|
||||||
# TimeLimit.truncated key may have been already set by the environment
|
|
||||||
# do not overwrite it
|
|
||||||
episode_truncated = not done or info.get("TimeLimit.truncated", False)
|
|
||||||
info["TimeLimit.truncated"] = episode_truncated
|
|
||||||
done = True
|
|
||||||
return observation, reward, done, info
|
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
|
||||||
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: The kwargs to reset the environment with
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The reset environment
|
|
||||||
"""
|
|
||||||
self._elapsed_steps = 0
|
|
||||||
return self.env.reset(**kwargs)
|
|
||||||
25
tdmpc2/envs/wrappers/timeout.py
Normal file
25
tdmpc2/envs/wrappers/timeout.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for enforcing a time limit on the environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env, max_episode_steps):
|
||||||
|
super().__init__(env)
|
||||||
|
self._max_episode_steps = max_episode_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_episode_steps(self):
|
||||||
|
return self._max_episode_steps
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
self._t = 0
|
||||||
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, done, info = self.env.step(action)
|
||||||
|
self._t += 1
|
||||||
|
done = done or self._t >= self.max_episode_steps
|
||||||
|
return obs, reward, done, info
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
os.environ['MUJOCO_GL'] = 'egl'
|
os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||||||
from common import math
|
from common import math
|
||||||
from common.scale import RunningScale
|
from common.scale import RunningScale
|
||||||
from common.world_model import WorldModel
|
from common.world_model import WorldModel
|
||||||
|
from common.layers import api_model_conversion
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
|
||||||
@@ -23,6 +24,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||||
{'params': self.model._dynamics.parameters()},
|
{'params': self.model._dynamics.parameters()},
|
||||||
{'params': self.model._reward.parameters()},
|
{'params': self.model._reward.parameters()},
|
||||||
|
{'params': self.model._termination.parameters() if self.cfg.episodic else []},
|
||||||
{'params': self.model._Qs.parameters()},
|
{'params': self.model._Qs.parameters()},
|
||||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||||
}
|
}
|
||||||
@@ -34,6 +36,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
|
print('Episode length:', cfg.episode_length)
|
||||||
|
print('Discount factor:', self.discount)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
print('Compiling update function with torch.compile...')
|
print('Compiling update function with torch.compile...')
|
||||||
@@ -82,8 +86,14 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
fp (str or dict): Filepath or state dict to load.
|
fp (str or dict): Filepath or state dict to load.
|
||||||
"""
|
"""
|
||||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
|
if isinstance(fp, dict):
|
||||||
self.model.load_state_dict(state_dict["model"])
|
state_dict = fp
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(fp, map_location=torch.get_default_device(), weights_only=False)
|
||||||
|
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||||
|
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
return
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def act(self, obs, t0=False, eval_mode=False, task=None):
|
def act(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
@@ -103,23 +113,28 @@ class TDMPC2(torch.nn.Module):
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
task = torch.tensor([task], device=self.device)
|
task = torch.tensor([task], device=self.device)
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
return self.plan(obs, t0=t0, eval_mode=eval_mode, task=task).cpu()
|
||||||
else:
|
z = self.model.encode(obs, task)
|
||||||
z = self.model.encode(obs, task)
|
action, info = self.model.pi(z, task)
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
if eval_mode:
|
||||||
return a.cpu()
|
action = info["mean"]
|
||||||
|
return action[0].cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||||
G, discount = 0, 1
|
G, discount = 0, 1
|
||||||
|
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[t], task)
|
||||||
G = G + discount * reward
|
G = G + discount * (1-termination) * reward
|
||||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
discount = discount * discount_update
|
discount = discount * discount_update
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
if self.cfg.episodic:
|
||||||
|
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
|
||||||
|
action, _ = self.model.pi(z, task)
|
||||||
|
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
@@ -141,9 +156,9 @@ class TDMPC2(torch.nn.Module):
|
|||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||||
for t in range(self.cfg.horizon-1):
|
for t in range(self.cfg.horizon-1):
|
||||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
pi_actions[t], _ = self.model.pi(_z, task)
|
||||||
_z = self.model.next(_z, pi_actions[t], task)
|
_z = self.model.next(_z, pi_actions[t], task)
|
||||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
pi_actions[-1], _ = self.model.pi(_z, task)
|
||||||
|
|
||||||
# Initialize state and parameters
|
# Initialize state and parameters
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.repeat(self.cfg.num_samples, 1)
|
||||||
@@ -183,7 +198,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
std = std * self.model._action_masks[task]
|
std = std * self.model._action_masks[task]
|
||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
|
||||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||||
a, std = actions[0], std[0]
|
a, std = actions[0], std[0]
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
@@ -202,43 +217,51 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
float: Loss of the policy update.
|
float: Loss of the policy update.
|
||||||
"""
|
"""
|
||||||
_, pis, log_pis, _ = self.model.pi(zs, task)
|
action, info = self.model.pi(zs, task)
|
||||||
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
|
qs = self.model.Q(zs, action, task, return_type='avg', detach=True)
|
||||||
self.scale.update(qs[0])
|
self.scale.update(qs[0])
|
||||||
qs = self.scale(qs)
|
qs = self.scale(qs)
|
||||||
|
|
||||||
# Loss is a weighted sum of Q-values
|
# Loss is a weighted sum of Q-values
|
||||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||||
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
pi_loss = (-(self.cfg.entropy_coef * info["scaled_entropy"] + qs).mean(dim=(1,2)) * rho).mean()
|
||||||
pi_loss.backward()
|
pi_loss.backward()
|
||||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||||
self.pi_optim.step()
|
self.pi_optim.step()
|
||||||
self.pi_optim.zero_grad(set_to_none=True)
|
self.pi_optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
return pi_loss.detach(), pi_grad_norm
|
info = TensorDict({
|
||||||
|
"pi_loss": pi_loss,
|
||||||
|
"pi_grad_norm": pi_grad_norm,
|
||||||
|
"pi_entropy": info["entropy"],
|
||||||
|
"pi_scaled_entropy": info["scaled_entropy"],
|
||||||
|
"pi_scale": self.scale.value,
|
||||||
|
})
|
||||||
|
return info
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _td_target(self, next_z, reward, task):
|
def _td_target(self, next_z, reward, terminated, task):
|
||||||
"""
|
"""
|
||||||
Compute the TD-target from a reward and the observation at the following time step.
|
Compute the TD-target from a reward and the observation at the following time step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
next_z (torch.Tensor): Latent state at the following time step.
|
next_z (torch.Tensor): Latent state at the following time step.
|
||||||
reward (torch.Tensor): Reward at the current time step.
|
reward (torch.Tensor): Reward at the current time step.
|
||||||
|
terminated (torch.Tensor): Termination signal at the current time step.
|
||||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: TD-target.
|
torch.Tensor: TD-target.
|
||||||
"""
|
"""
|
||||||
pi = self.model.pi(next_z, task)[1]
|
action, _ = self.model.pi(next_z, task)
|
||||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||||
|
|
||||||
def _update(self, obs, action, reward, task=None):
|
def _update(self, obs, action, reward, terminated, task=None):
|
||||||
# Compute targets
|
# Compute targets
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_z = self.model.encode(obs[1:], task)
|
next_z = self.model.encode(obs[1:], task)
|
||||||
td_targets = self._td_target(next_z, reward, task)
|
td_targets = self._td_target(next_z, reward, terminated, task)
|
||||||
|
|
||||||
# Prepare for update
|
# Prepare for update
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@@ -257,6 +280,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
|
if self.cfg.episodic:
|
||||||
|
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, value_loss = 0, 0
|
reward_loss, value_loss = 0, 0
|
||||||
@@ -267,10 +292,15 @@ class TDMPC2(torch.nn.Module):
|
|||||||
|
|
||||||
consistency_loss = consistency_loss / self.cfg.horizon
|
consistency_loss = consistency_loss / self.cfg.horizon
|
||||||
reward_loss = reward_loss / self.cfg.horizon
|
reward_loss = reward_loss / self.cfg.horizon
|
||||||
|
if self.cfg.episodic:
|
||||||
|
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||||
|
else:
|
||||||
|
termination_loss = 0.
|
||||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
self.cfg.reward_coef * reward_loss +
|
self.cfg.reward_coef * reward_loss +
|
||||||
|
self.cfg.termination_coef * termination_loss +
|
||||||
self.cfg.value_coef * value_loss
|
self.cfg.value_coef * value_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -281,23 +311,25 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Update policy
|
# Update policy
|
||||||
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
|
pi_info = self.update_pi(zs.detach(), task)
|
||||||
|
|
||||||
# Update target Q-functions
|
# Update target Q-functions
|
||||||
self.model.soft_update_target_Q()
|
self.model.soft_update_target_Q()
|
||||||
|
|
||||||
# Return training statistics
|
# Return training statistics
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
return TensorDict({
|
info = TensorDict({
|
||||||
"consistency_loss": consistency_loss,
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": reward_loss,
|
"reward_loss": reward_loss,
|
||||||
"value_loss": value_loss,
|
"value_loss": value_loss,
|
||||||
"pi_loss": pi_loss,
|
"termination_loss": termination_loss,
|
||||||
"total_loss": total_loss,
|
"total_loss": total_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
"pi_grad_norm": pi_grad_norm,
|
})
|
||||||
"pi_scale": self.scale.value,
|
if self.cfg.episodic:
|
||||||
}).detach().mean()
|
info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
|
||||||
|
info.update(pi_info)
|
||||||
|
return info.detach().mean()
|
||||||
|
|
||||||
def update(self, buffer):
|
def update(self, buffer):
|
||||||
"""
|
"""
|
||||||
@@ -309,9 +341,9 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary of training statistics.
|
dict: Dictionary of training statistics.
|
||||||
"""
|
"""
|
||||||
obs, action, reward, task = buffer.sample()
|
obs, action, reward, terminated, task = buffer.sample()
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if task is not None:
|
if task is not None:
|
||||||
kwargs["task"] = task
|
kwargs["task"] = task
|
||||||
torch.compiler.cudagraph_mark_step_begin()
|
torch.compiler.cudagraph_mark_step_begin()
|
||||||
return self._update(obs, action, reward, **kwargs)
|
return self._update(obs, action, reward, terminated, **kwargs)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
os.environ['MUJOCO_GL'] = 'egl'
|
os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
|
||||||
os.environ['LAZY_LEGACY_OP'] = '0'
|
os.environ['LAZY_LEGACY_OP'] = '0'
|
||||||
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
|
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
|
||||||
os.environ['TORCH_LOGS'] = "+recompiles"
|
os.environ['TORCH_LOGS'] = "+recompiles"
|
||||||
|
|||||||
@@ -38,19 +38,15 @@ class OfflineTrainer(Trainer):
|
|||||||
f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards),
|
f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards),
|
||||||
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
|
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def train(self):
|
def _load_dataset(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Load dataset for offline training."""
|
||||||
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
|
||||||
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||||
fps = sorted(glob(str(fp)))
|
fps = sorted(glob(str(fp)))
|
||||||
assert len(fps) > 0, f'No data found at {fp}'
|
assert len(fps) > 0, f'No data found at {fp}'
|
||||||
print(f'Found {len(fps)} files in {fp}')
|
print(f'Found {len(fps)} files in {fp}')
|
||||||
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
|
if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
|
||||||
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
|
print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
|
||||||
|
|
||||||
# Create buffer for sampling
|
# Create buffer for sampling
|
||||||
_cfg = deepcopy(self.cfg)
|
_cfg = deepcopy(self.cfg)
|
||||||
@@ -59,15 +55,20 @@ class OfflineTrainer(Trainer):
|
|||||||
_cfg.steps = _cfg.buffer_size
|
_cfg.steps = _cfg.buffer_size
|
||||||
self.buffer = Buffer(_cfg)
|
self.buffer = Buffer(_cfg)
|
||||||
for fp in tqdm(fps, desc='Loading data'):
|
for fp in tqdm(fps, desc='Loading data'):
|
||||||
td = torch.load(fp)
|
td = torch.load(fp, weights_only=False)
|
||||||
assert td.shape[1] == _cfg.episode_length, \
|
assert td.shape[1] == _cfg.episode_length, \
|
||||||
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
||||||
f'please double-check your config.'
|
f'please double-check your config.'
|
||||||
for i in range(len(td)):
|
self.buffer.load(td)
|
||||||
self.buffer.add(td[i])
|
|
||||||
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
||||||
assert self.buffer.num_eps == expected_episodes, \
|
if self.buffer.num_eps != expected_episodes:
|
||||||
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
"""Train a TD-MPC2 agent."""
|
||||||
|
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
||||||
|
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
||||||
|
self._load_dataset()
|
||||||
|
|
||||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||||
metrics = {}
|
metrics = {}
|
||||||
@@ -80,7 +81,7 @@ class OfflineTrainer(Trainer):
|
|||||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
||||||
metrics = {
|
metrics = {
|
||||||
'iteration': i,
|
'iteration': i,
|
||||||
'total_time': time() - self._start_time,
|
'elapsed_time': time() - self._start_time,
|
||||||
}
|
}
|
||||||
metrics.update(train_metrics)
|
metrics.update(train_metrics)
|
||||||
if i % self.cfg.eval_freq == 0:
|
if i % self.cfg.eval_freq == 0:
|
||||||
|
|||||||
@@ -17,15 +17,17 @@ class OnlineTrainer(Trainer):
|
|||||||
|
|
||||||
def common_metrics(self):
|
def common_metrics(self):
|
||||||
"""Return a dictionary of current metrics."""
|
"""Return a dictionary of current metrics."""
|
||||||
|
elapsed_time = time() - self._start_time
|
||||||
return dict(
|
return dict(
|
||||||
step=self._step,
|
step=self._step,
|
||||||
episode=self._ep_idx,
|
episode=self._ep_idx,
|
||||||
total_time=time() - self._start_time,
|
elapsed_time=elapsed_time,
|
||||||
|
steps_per_second=self._step / elapsed_time
|
||||||
)
|
)
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Evaluate a TD-MPC2 agent."""
|
"""Evaluate a TD-MPC2 agent."""
|
||||||
ep_rewards, ep_successes = [], []
|
ep_rewards, ep_successes, ep_lengths = [], [], []
|
||||||
for i in range(self.cfg.eval_episodes):
|
for i in range(self.cfg.eval_episodes):
|
||||||
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
@@ -40,14 +42,16 @@ class OnlineTrainer(Trainer):
|
|||||||
self.logger.video.record(self.env)
|
self.logger.video.record(self.env)
|
||||||
ep_rewards.append(ep_reward)
|
ep_rewards.append(ep_reward)
|
||||||
ep_successes.append(info['success'])
|
ep_successes.append(info['success'])
|
||||||
|
ep_lengths.append(t)
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
self.logger.video.save(self._step)
|
self.logger.video.save(self._step)
|
||||||
return dict(
|
return dict(
|
||||||
episode_reward=np.nanmean(ep_rewards),
|
episode_reward=np.nanmean(ep_rewards),
|
||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
|
episode_length= np.nanmean(ep_lengths),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None):
|
def to_td(self, obs, action=None, reward=None, terminated=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict(obs, batch_size=(), device='cpu')
|
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||||
@@ -57,10 +61,13 @@ class OnlineTrainer(Trainer):
|
|||||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
|
if terminated is None:
|
||||||
|
terminated = torch.tensor(float('nan'))
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
obs=obs,
|
obs=obs,
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
|
terminated=terminated.unsqueeze(0),
|
||||||
batch_size=(1,))
|
batch_size=(1,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
@@ -81,10 +88,14 @@ class OnlineTrainer(Trainer):
|
|||||||
eval_next = False
|
eval_next = False
|
||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
|
if info['terminated'] and not self.cfg.episodic:
|
||||||
|
raise ValueError('Termination detected but you are not in episodic mode. ' \
|
||||||
|
'Set `episodic=true` to enable support for terminations.')
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
)
|
episode_length=len(self._tds),
|
||||||
|
episode_terminated=info['terminated'])
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
@@ -98,7 +109,7 @@ class OnlineTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
action = self.env.rand_act()
|
action = self.env.rand_act()
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
self._tds.append(self.to_td(obs, action, reward, info['terminated']))
|
||||||
|
|
||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user