50 Commits

Author SHA1 Message Date
Nicklas Hansen
97c1447199 minor updates to vectorization 2025-05-21 16:06:45 -07:00
Nicklas Hansen
a586d8f393 fix merge error 2025-05-20 14:09:13 -07:00
Nicklas Hansen
6116eb3fa5 fix merge error 2025-05-20 13:59:12 -07:00
Nicklas Hansen
491d367fc6 Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env 2025-05-20 13:44:41 -07:00
Nicklas Hansen
10f368f20d init 2025-05-20 13:42:28 -07:00
Nicklas Hansen
829e329b3b init 2025-05-20 13:40:02 -07:00
Sue Hyun Park
8bbc14ebab Fix: handle _action_masks buffer in single-task scenarios (#67) 2025-05-19 20:12:12 -07:00
Nicklas Hansen
7992fa193e update readme + dockerfile 2025-05-13 14:51:16 -07:00
Nicklas Hansen
7ec6bc83a8 only instantiate termination pred head if episodic=true 2025-05-02 16:51:24 -07:00
Nicklas Hansen
38b31a5d72 only instantiate termination pred head if episodic=true 2025-05-02 16:24:02 -07:00
Nicklas Hansen
7942e9082b update readme + clean up 2025-04-15 16:32:15 -07:00
Nicklas Hansen
eece80123d full support for episodic rl 2025-04-15 15:55:05 -07:00
Nicklas Hansen
38f853efc4 clean up 2025-04-15 10:16:02 -07:00
Nicklas Hansen
62be41ab58 experimental changes to termination prediction 2025-04-10 00:32:13 -07:00
Nicklas Hansen
c95b755655 add walker2d 2025-04-09 15:55:57 -07:00
Nicklas Hansen
81eb17068e QoL improvements to termination signal debugging 2025-04-08 19:15:31 -07:00
Nicklas Hansen
add30b5a74 merge main into branch + fix termination in td-targets 2025-04-08 12:40:10 -07:00
Nicklas Hansen
0a914570dc fix multitask model api conversion 2025-02-27 16:25:21 -08:00
Nicklas Hansen
55bde9745f update torchrl version 2025-02-05 16:46:34 -08:00
Nicklas Hansen
5ced6dfeb4 auto-convert old checkpoints to new format 2025-02-05 16:26:19 -08:00
Nicklas Hansen
dddc226d25 partial fix to loading checkpoints 2025-01-21 00:10:53 -08:00
Vincent Moens
ae4238946f Conversion tools for state-dicts (#55)
* init

* init

* amend
2025-01-20 15:49:36 -08:00
Nicklas Hansen
a19f91c0b5 enable torch.compile for offline rl + rgb inputs 2024-12-25 12:22:39 -08:00
Nicklas Hansen
e452ca7539 factor pi outputs 2024-12-25 12:08:07 -08:00
Nicklas Hansen
db1865334e refactor pi outputs 2024-12-25 12:02:33 -08:00
Nicklas Hansen
804f9b3949 refactor pi outputs 2024-12-24 03:05:00 -08:00
Nicklas Hansen
66f8c21f58 cache buffer values in offline training 2024-12-19 09:40:04 -08:00
Nicklas Hansen
9cac7c5775 faster offline data loading 2024-12-19 06:52:31 -08:00
Nicklas Hansen
df8a465c8e update offline trainer to use new torch.load api 2024-12-10 16:30:05 -08:00
Nicklas Hansen
2e27fbb6f4 partial fix for loading old checkpoints 2024-12-10 16:04:27 -08:00
Nicklas Hansen
6117bc427d simplify dmcontrol wrappers + upgrade to gymnasium==0.29.1 2024-12-10 15:16:34 -08:00
Nicklas Hansen
32fc2bdf93 refactor policy 2024-12-03 12:22:02 -08:00
Nicklas Hansen
10a0be2724 fix indexing 2024-11-10 23:16:32 -08:00
Nicklas Hansen
ad2342e258 Merge branch 'vectorized_env' of github.com:nicklashansen/tdmpc2 into vectorized_env 2024-11-10 13:04:54 -08:00
Nicklas Hansen
fa41a3e450 init 2024-11-10 13:00:56 -08:00
Nicklas Hansen
3789fcd5b8 update pinned torchrl version 2024-07-02 10:12:06 -07:00
Nicklas Hansen
d51feb0e9f Update README.md 2024-07-02 10:12:06 -07:00
Nicklas Hansen
2dc668ecaf reduce # wandb calls 2024-07-02 10:12:06 -07:00
Nicklas Hansen
39be86fc52 update dockerfile 2024-07-02 10:12:06 -07:00
Nicklas Hansen
f0af4b6b27 update dockerfile + pin all versions 2024-07-02 10:12:06 -07:00
Nicklas Hansen
a2d9d0c8ff minor fix in print 2024-07-02 10:12:06 -07:00
Nicklas Hansen
ab43880945 migrate to slicebuffer from torchrl-nightly 2024-07-02 10:12:06 -07:00
Nicklas Hansen
f6d1bfe12d update pinned torchrl version 2024-07-02 10:11:13 -07:00
Nicklas Hansen
9dd3e673c4 clean up 2024-02-11 14:44:16 -08:00
Nicklas Hansen
51d6b8d7a9 init 2024-02-11 14:41:20 -08:00
Nicklas Hansen
ff02f41e73 fix 2024-01-08 17:18:22 -08:00
Nicklas Hansen
e86c343a67 Merge branch 'episodic-rl' of github.com:nicklashansen/tdmpc2 into episodic-rl 2024-01-08 10:55:10 -08:00
Nicklas Hansen
cc62c4c9ce init 2024-01-08 10:51:27 -08:00
Nicklas Hansen
fabf01a5ec solves episodic variant of cartpole-balance-sparse 2024-01-07 19:28:41 -08:00
Nicklas Hansen
26c72119cd init 2024-01-07 18:16:33 -08:00
29 changed files with 653 additions and 671 deletions

View File

@@ -12,13 +12,9 @@ Official implementation of
---- ----
**Discrete branch:** this branch is under active development and contains experimental support for discrete action spaces. We expect a stable release to be available in a few months. Please use the `main` branch for the time being. **Announcement (Apr 2025): support for episodic tasks!**
---- We have added support for episodic RL (tasks with terminations) in the latest release. This functionality can be enabled with `episodic=true` but remains disabled by default to ensure reproducibility of results across releases.
**Announcement: training just got ~4.5x faster!**
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
---- ----
@@ -40,19 +36,18 @@ You will need a machine with a GPU and at least 12 GB of RAM for single-task onl
We provide a `Dockerfile` for easy installation. You can build the docker image by running We provide a `Dockerfile` for easy installation. You can build the docker image by running
``` ```
cd docker && docker build . -t <user>/tdmpc2:1.0.0 cd docker && docker build . -t <user>/tdmpc2:1.0.1
``` ```
This docker image contains all dependencies needed for running DMControl, Meta-World, and ManiSkill2 experiments. This docker image contains all dependencies needed for running DMControl. We also provide a pre-built docker image [here](https://hub.docker.com/repository/docker/nicklashansen/tdmpc2/tags/1.0.1/sha256-b07d4e04d4b28ffd9a63ac18ec1541950e874bb51d276c7d09b36135f170dd93).
If you prefer to install dependencies manually, start by installing dependencies via `conda` by running the following command: If you prefer to use `conda` rather than docker, start by running the following command:
``` ```
conda env create -f docker/environment.yaml conda env create -f docker/environment.yaml
pip install gym==0.21.0
``` ```
The `environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `environment.yaml`. The `docker/environment.yaml` file installs dependencies required for training on DMControl tasks. Other domains can be installed by following the instructions in `docker/environment.yaml`.
If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running If you want to run ManiSkill2, you will additionally need to download and link the necessary assets by running
@@ -66,19 +61,19 @@ which downloads assets to `./data`. You may move these assets to any location. T
export MS2_ASSET_DIR=<path>/<to>/<data> export MS2_ASSET_DIR=<path>/<to>/<data>
``` ```
and restart your terminal. Meta-World additionally requires MuJoCo 2.1.0. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at [https://www.tdmpc2.com/files/mjkey.txt](https://www.tdmpc2.com/files/mjkey.txt). You can download the license by running and restart your terminal. Note that Meta-World requires MuJoCo 2.1.0 and `gym==0.21.0` which is becoming increasingly difficult to install. We host the unrestricted MuJoCo 2.1.0 license (courtesy of Google DeepMind) at [https://www.tdmpc2.com/files/mjkey.txt](https://www.tdmpc2.com/files/mjkey.txt). You can download the license by running
``` ```
wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt wget https://www.tdmpc2.com/files/mjkey.txt -O ~/.mujoco/mjkey.txt
``` ```
See `docker/Dockerfile` for installation instructions if you do not already have MuJoCo 2.1.0 installed. MyoSuite requires `gym==0.13.0` which is incompatible with Meta-World and ManiSkill2. Install separately with `pip install myosuite` if desired. Depending on your existing system packages, you may need to install other dependencies. See `docker/Dockerfile` for a list of recommended system packages. Depending on your existing system packages, you may need to install other dependencies. See `docker/Dockerfile` for a list of recommended system packages.
---- ----
## Supported tasks ## Supported tasks
This codebase currently supports **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite**. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain: This codebase provides support for all **104** continuous control tasks from **DMControl**, **Meta-World**, **ManiSkill2**, and **MyoSuite** used in our paper. Specifically, it supports 39 tasks from DMControl (including 11 custom tasks), 50 tasks from Meta-World, 5 tasks from ManiSkill2, and 10 tasks from MyoSuite, and covers all tasks used in the paper. See below table for expected name formatting for each task domain:
| domain | task | domain | task
| --- | --- | | --- | --- |
@@ -91,9 +86,9 @@ This codebase currently supports **104** continuous control tasks from **DMContr
| myosuite | myo-key-turn | myosuite | myo-key-turn
| myosuite | myo-key-turn-hard | myosuite | myo-key-turn-hard
which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. which can be run by specifying the `task` argument for `evaluation.py`. Multi-task training and evaluation is specified by setting `task=mt80` or `task=mt30` for the 80-task and 30-task sets, respectively. While you generally do not need to access the underlying task IDs or embeddings during training or evaluation of our multi-task models, the mapping from task name to task embedding used in our work can be found [here](https://github.com/nicklashansen/tdmpc2/blob/7ec6bc83a82a5188ca3faddc59aea83f430ab570/tdmpc2/common/__init__.py#L26). As of April 2025, our codebase also provides basic support for other MuJoCo/Box2d Gymnasium tasks; refer to the `envs` directory for a list of tasks. It should be relatively straightforward to add support for custom tasks by following the examples in `envs`.
**As of Dec 27, 2023 the TD-MPC2 codebase also supports pixel observations for DMControl tasks**; use argument `obs=rgb` if you wish to train visual policies. **Note:** we also provide support for image observations in the DMControl tasks. Use argument `obs=rgb` if you wish to train visual policies.
## Example usage ## Example usage
@@ -125,8 +120,6 @@ $ python train.py task=walker-walk obs=rgb
We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments. We recommend using default hyperparameters for single-task online RL, including the default model size of 5M parameters (`model_size=5`). Multi-task offline RL benefits from a larger model size, but larger models are also increasingly costly to train and evaluate. Available arguments are `model_size={1, 5, 19, 48, 317}`. See `config.yaml` for a full list of arguments.
**As of Jan 7, 2024 the TD-MPC2 codebase also supports multi-GPU training for multi-task offline RL experiments**; use branch `distributed` and argument `world_size=N` to train on `N` GPUs. We cannot guarantee that distributed training will yield the same results, but they appear to be similar based on our limited testing.
---- ----
## Citation ## Citation

View File

@@ -4,14 +4,14 @@
# https://www.tdmpc2.com # # https://www.tdmpc2.com #
# -------------------------------------- # # -------------------------------------- #
# Build instructions: # # Build instructions: #
# docker build . -t <user>/tdmpc2:1.0.0 # # docker build . -t <user>/tdmpc2:1.0.1 #
# docker push <user>/tdmpc2:1.0.0 # # docker push <user>/tdmpc2:1.0.1 #
# -------------------------------------- # # -------------------------------------- #
# Run: # # Run: #
# docker run -i \ # # docker run -i \ #
# -v <path>/<to>/tdmpc2:/tdmpc2 \ # # -v <path>/<to>/tdmpc2:/tdmpc2 \ #
# --gpus all \ # # --gpus all \ #
# -t <user>/tdmpc2:1.0.0 \ # # -t <user>/tdmpc2:1.0.1 \ #
# /bin/bash # # /bin/bash #
########################################## ##########################################
@@ -42,34 +42,8 @@ RUN conda update conda && \
conda init conda init
SHELL ["/bin/bash", "-c"] SHELL ["/bin/bash", "-c"]
RUN echo "cd /root" >> /root/.bashrc RUN echo "cd /root" >> /root/.bashrc
# image does not include metaworld, maniskill, myosuite
# mujoco 2.1.0 # these can be installed separately; see environment.yaml for details
ENV MUJOCO_GL egl
ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
RUN mkdir -p /root/.mujoco && \
wget https://www.tdmpc2.com/files/mjkey.txt && \
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \
tar -xzf mujoco210-linux-x86_64.tar.gz && \
rm mujoco210-linux-x86_64.tar.gz && \
mv mujoco210 /root/.mujoco/mujoco210 && \
mv mjkey.txt /root/.mujoco/mjkey.txt && \
find /root/.mujoco -uid 421709 -exec chown root:root {} \; && \
python -c "import mujoco_py"
# gym
RUN pip install gym==0.21.0
# metaworld
RUN pip install git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
# maniskill2
ENV MS2_ASSET_DIR /root/data
RUN pip install mani-skill2==0.4.1 && \
cd /root && \
python -m mani_skill2.utils.download_asset all -y
# myosuite (conflicts with meta-world / mani-skill2)
# RUN pip install myosuite
# success! # success!
RUN echo "Successfully built TD-MPC2 Docker image!" RUN echo "Successfully built TD-MPC2 Docker image!"

View File

@@ -13,10 +13,9 @@ dependencies:
- pytorch-cuda=12.4 - pytorch-cuda=12.4
- torchvision=0.15.2 - torchvision=0.15.2
- pip: - pip:
- absl-py==2.1.0 - dm-control==1.0.16
- "cython<3"
- dm-control==1.0.8
- glfw==2.7.0 - glfw==2.7.0
- gymnasium==0.29.1
- ffmpeg==1.4 - ffmpeg==1.4
- imageio==2.34.1 - imageio==2.34.1
- imageio-ffmpeg==0.4.9 - imageio-ffmpeg==0.4.9
@@ -24,24 +23,25 @@ dependencies:
- hydra-core==1.3.2 - hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0 - hydra-submitit-launcher==1.2.0
- submitit==1.5.1 - submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- omegaconf==2.3.0 - omegaconf==2.3.0
- moviepy==1.0.3 - moviepy==1.0.3
- mujoco==2.3.1 - mujoco==3.1.2
- mujoco-py==2.1.2.14
- numpy==1.24.4 - numpy==1.24.4
- tensordict-nightly==2024.11.14 - tensordict-nightly==2025.1.1
- torchrl-nightly==2024.11.14 - torchrl-nightly==2025.1.1
- kornia==0.7.2 - kornia==0.7.2
- termcolor==2.4.0 - termcolor==2.4.0
- tqdm==4.66.4 - tqdm==4.66.4
- pandas==2.0.3 - pandas==2.0.3
- wandb==0.17.4 - wandb==0.17.4
- wheel==0.38.0
#################### ####################
# Gym: # Gym:
# (unmaintained but required for maniskill2/meta-world) # (unmaintained but required for maniskill2/meta-world)
# - "cython<3"
# - wheel==0.38.0
# - setuptools==65.5.0
# - mujoco==2.3.1
# - mujoco-py==2.1.2.14
# - gym==0.21.0 # - gym==0.21.0
#################### ####################
# ManiSkill2: # ManiSkill2:
@@ -55,3 +55,7 @@ dependencies:
# MyoSuite: # MyoSuite:
# - myosuite # - myosuite
#################### ####################
# Classic MuJoCo/Box2d:
# - swig
# - gymnasium[box2d]
####################

View File

@@ -20,6 +20,7 @@ class Buffer():
traj_key='episode', traj_key='episode',
truncated_key=None, truncated_key=None,
strict_length=True, strict_length=True,
cache_values=cfg.multitask,
) )
self._batch_size = cfg.batch_size * (cfg.horizon+1) self._batch_size = cfg.batch_size * (cfg.horizon+1)
self._num_eps = 0 self._num_eps = 0
@@ -65,28 +66,50 @@ class Buffer():
LazyTensorStorage(self._capacity, device=self._storage_device) LazyTensorStorage(self._capacity, device=self._storage_device)
) )
def load(self, td):
"""
Load a batch of episodes into the buffer. This is useful for loading data from disk,
and is more efficient than adding episodes one by one.
"""
num_new_eps = len(td)
episode_idx = torch.arange(self._num_eps, self._num_eps+num_new_eps, dtype=torch.int64)
td['episode'] = episode_idx.unsqueeze(-1).expand(-1, td['reward'].shape[1])
if self._num_eps == 0:
self._buffer = self._init(td[0])
td = td.reshape(td.shape[0]*td.shape[1])
self._buffer.extend(td)
self._num_eps += num_new_eps
return self._num_eps
def add(self, td):
"""Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * torch.arange(self._num_eps, self._num_eps+self.cfg.num_envs)
td = td.permute(1, 0)
if self._num_eps == 0:
self._buffer = self._init(td[0])
for i in range(self.cfg.num_envs):
self._buffer.extend(td[i])
self._num_eps += self.cfg.num_envs
return self._num_eps
def _prepare_batch(self, td): def _prepare_batch(self, td):
""" """
Prepare a sampled batch for training (post-processing). Prepare a sampled batch for training (post-processing).
Expects `td` to be a TensorDict with batch size TxB. Expects `td` to be a TensorDict with batch size TxB.
""" """
td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True) td = td.select("obs", "action", "reward", "terminated", "task", strict=False).to(self._device, non_blocking=True)
obs = td.get('obs').contiguous() obs = td.get('obs').contiguous()
action = td.get('action')[1:].contiguous() action = td.get('action')[1:].contiguous()
reward = td.get('reward')[1:].unsqueeze(-1).contiguous() reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
terminated = td.get('terminated', None)
if terminated is not None:
terminated = td.get('terminated')[1:].unsqueeze(-1).contiguous()
else:
terminated = torch.zeros_like(reward)
task = td.get('task', None) task = td.get('task', None)
if task is not None: if task is not None:
task = task[0].contiguous() task = task[0].contiguous()
return obs, action, reward, task return obs, action, reward, terminated, task
def add(self, td):
"""Add an episode to the buffer."""
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
if self._num_eps == 0:
self._buffer = self._init(td)
self._buffer.extend(td)
self._num_eps += 1
return self._num_eps
def sample(self): def sample(self):
"""Sample a batch of subsequences from the buffer.""" """Sample a batch of subsequences from the buffer."""

View File

@@ -4,6 +4,7 @@ import torch.nn.functional as F
from tensordict import from_modules from tensordict import from_modules
from copy import deepcopy from copy import deepcopy
class Ensemble(nn.Module): class Ensemble(nn.Module):
""" """
Vectorized ensemble of modules. Vectorized ensemble of modules.
@@ -15,7 +16,11 @@ class Ensemble(nn.Module):
self.params = from_modules(*modules, as_module=True) self.params = from_modules(*modules, as_module=True)
with self.params[0].data.to("meta").to_module(modules[0]): with self.params[0].data.to("meta").to_module(modules[0]):
self.module = deepcopy(modules[0]) self.module = deepcopy(modules[0])
self._repr = str(modules) self._repr = str(modules[0])
self._n = len(modules)
def __len__(self):
return self._n
def _call(self, params, *args, **kwargs): def _call(self, params, *args, **kwargs):
with params.to_module(self.module): with params.to_module(self.module):
@@ -25,7 +30,7 @@ class Ensemble(nn.Module):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
def __repr__(self): def __repr__(self):
return 'Vectorized ' + self._repr return f'Vectorized {len(self)}x ' + self._repr
class ShiftAug(nn.Module): class ShiftAug(nn.Module):
@@ -157,3 +162,60 @@ def enc(cfg, out={}):
else: else:
raise NotImplementedError(f"Encoder for observation type {k} not implemented.") raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
return nn.ModuleDict(out) return nn.ModuleDict(out)
def api_model_conversion(target_state_dict, source_state_dict):
"""
Converts a checkpoint from our old API to the new torch.compile compatible API.
"""
# check whether checkpoint is already in the new format
if "_detach_Qs_params.0.weight" in source_state_dict:
return source_state_dict
name_map = ['weight', 'bias', 'ln.weight', 'ln.bias']
new_state_dict = dict()
# rename keys
for key, val in list(source_state_dict.items()):
if key.startswith('_Qs.'):
num = key[len('_Qs.params.'):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = "_Qs.params." + new_key
del source_state_dict[key]
new_state_dict[new_total_key] = val
new_total_key = "_detach_Qs_params." + new_key
new_state_dict[new_total_key] = val
elif key.startswith('_target_Qs.'):
num = key[len('_target_Qs.params.'):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = "_target_Qs_params." + new_key
del source_state_dict[key]
new_state_dict[new_total_key] = val
# add batch_size and device from target_state_dict to new_state_dict
for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_'):
for key in ('__batch_size', '__device'):
new_key = prefix + 'params.' + key
new_state_dict[new_key] = target_state_dict[new_key]
# check that every key in new_state_dict is in target_state_dict
for key in new_state_dict.keys():
assert key in target_state_dict, f"key {key} not in target_state_dict"
# check that all Qs keys in target_state_dict are in new_state_dict
for key in target_state_dict.keys():
if 'Qs' in key:
assert key in new_state_dict, f"key {key} not in new_state_dict"
# check that source_state_dict contains no Qs keys
for key in source_state_dict.keys():
assert 'Qs' not in key, f"key {key} contains 'Qs'"
# copy log_std_min and log_std_max from target_state_dict to new_state_dict
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
if '_action_masks' in target_state_dict:
new_state_dict['_action_masks'] = target_state_dict['_action_masks']
# copy new_state_dict to source_state_dict
source_state_dict.update(new_state_dict)
return source_state_dict

View File

@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
("step", "I", "int"), ("step", "I", "int"),
("episode_reward", "R", "float"), ("episode_reward", "R", "float"),
("episode_success", "S", "float"), ("episode_success", "S", "float"),
("total_time", "T", "time"), ("elapsed_time", "T", "time"),
] ]
CAT_TO_COLOR = { CAT_TO_COLOR = {

View File

@@ -1,5 +1,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from tensordict import TensorDict
def soft_ce(pred, target, cfg): def soft_ce(pred, target, cfg):
@@ -13,32 +14,19 @@ def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1) return low + 0.5 * dif * (torch.tanh(x) + 1)
def _gaussian_residual(eps, log_std): def gaussian_logprob(eps, log_std):
return -0.5 * eps.pow(2) - log_std
def _gaussian_logprob(residual):
log2pi = 1.8378770351409912
return residual - 0.5 * log2pi
def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability.""" """Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) residual = -0.5 * eps.pow(2) - log_std
if size is None: log_prob = residual - 0.9189385175704956
size = eps.shape[-1] return log_prob.sum(-1, keepdim=True)
return _gaussian_logprob(residual) * size
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
def squash(mu, pi, log_pi): def squash(mu, pi, log_pi):
"""Apply squashing function.""" """Apply squashing function."""
mu = torch.tanh(mu) mu = torch.tanh(mu)
pi = torch.tanh(pi) pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True) squashed_pi = torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
log_pi = log_pi - squashed_pi.sum(-1, keepdim=True)
return mu, pi, log_pi return mu, pi, log_pi
@@ -96,12 +84,24 @@ def two_hot_inv(x, cfg):
return symexp(x) return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0): def gumbel_softmax_sample(p, temperature=1.0, dim=1):
logits = p.log() """Sample indices from a Gumbel-Softmax distribution."""
# Generate Gumbel noise logits = torch.log(p + 1e-9)
gumbels = ( gumbels = -torch.empty_like(logits).exponential_().log()
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() y = (logits + gumbels) / temperature
) # ~Gumbel(0,1) return y.argmax(dim=dim)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1) def termination_statistics(pred, target, eps=1e-9):
"""Compute episode termination statistics."""
pred = pred.squeeze(-1)
target = target.squeeze(-1)
rate = target.sum() / len(target)
tp = ((pred > 0.5) & (target == 1)).sum()
fn = ((pred <= 0.5) & (target == 1)).sum()
fp = ((pred > 0.5) & (target == 0)).sum()
recall = tp / (tp + fn + eps)
precision = tp / (tp + fp + eps)
f1 = 2 * (precision * recall) / (precision + recall + eps)
return TensorDict({'termination_rate': rate,
'termination_f1': f1})

View File

@@ -77,9 +77,8 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0 cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Check torch.compile compatibility # Ensure that eval_episodes is divisible by num_envs and is at least 1*num_envs
if cfg.get('compile', False): cfg.eval_episodes = max(cfg.eval_episodes, cfg.num_envs)
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' cfg.eval_episodes = cfg.eval_episodes - (cfg.eval_episodes % cfg.num_envs)
assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.'
return cfg_to_dataclass(cfg) return cfg_to_dataclass(cfg)

View File

@@ -1,14 +1,15 @@
import torch import torch
from torch.nn import Buffer from torch.nn import Buffer
class RunningScale(torch.nn.Module): class RunningScale(torch.nn.Module):
"""Running trimmed scale estimator.""" """Running trimmed scale estimator."""
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda:0')))
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))) self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda:0')))
def state_dict(self): def state_dict(self):
return dict(value=self.value, percentiles=self._percentiles) return dict(value=self.value, percentiles=self._percentiles)

View File

@@ -2,11 +2,10 @@ from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tensordict.nn import TensorDictParams
from common import layers, math, init from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
@@ -26,7 +25,8 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action_space == 'continuous' else cfg.action_dim) self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) if cfg.episodic else None
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
@@ -46,13 +46,18 @@ class WorldModel(nn.Module):
self._target_Qs = deepcopy(self._Qs) self._target_Qs = deepcopy(self._Qs)
# Assign params to modules # Assign params to modules
self._detach_Qs.params = self._detach_Qs_params # We do this strange assignment to avoid having duplicated tensors in the state-dict -- working on a better API for this
self._target_Qs.params = self._target_Qs_params delattr(self._detach_Qs, "params")
self._detach_Qs.__dict__["params"] = self._detach_Qs_params
delattr(self._target_Qs, "params")
self._target_Qs.__dict__["params"] = self._target_Qs_params
def __repr__(self): def __repr__(self):
repr = 'TD-MPC2 World Model\n' repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions'] modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]): for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
if m == self._termination and not self.cfg.episodic:
continue
repr += f"{modules[i]}: {m}\n" repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params) repr += "Learnable parameters: {:,}".format(self.total_params)
return repr return repr
@@ -124,65 +129,59 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def _continuous_pi(self, z, task): def termination(self, z, task, unnormalized=False):
"""
Predicts termination signal.
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
if unnormalized:
return self._termination(z)
return torch.sigmoid(self._termination(z))
def pi(self, z, task):
""" """
Samples an action from the policy prior. Samples an action from the policy prior.
The policy prior is a Gaussian distribution with The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network. mean and (log) std predicted by a neural network.
""" """
if self.cfg.multitask:
z = self.task_emb(z, task)
# Gaussian policy prior # Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1) mean, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
eps = torch.randn_like(mu) eps = torch.randn_like(mean)
if self.cfg.multitask: # Mask out unused action dimensions if self.cfg.multitask: # Mask out unused action dimensions
mu = mu * self._action_masks[task] mean = mean * self._action_masks[task]
log_std = log_std * self._action_masks[task] log_std = log_std * self._action_masks[task]
eps = eps * self._action_masks[task] eps = eps * self._action_masks[task]
action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1) action_dims = self._action_masks.sum(-1)[task].unsqueeze(-1)
else: # No masking else: # No masking
action_dims = None action_dims = None
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims) log_prob = math.gaussian_logprob(eps, log_std)
pi = mu + eps * log_std.exp()
mu, pi, log_pi = math.squash(mu, pi, log_pi)
return mu, pi, log_pi, log_std # Scale log probability by action dimensions
size = eps.shape[-1] if action_dims is None else action_dims
scaled_log_prob = log_prob * size
def _discrete_pi(self, z, task): # Reparameterization trick
""" action = mean + eps * log_std.exp()
Samples an action from the policy prior. mean, action, log_prob = math.squash(mean, action, log_prob)
The policy prior is a categorical distribution
with logits predicted by a neural network.
"""
assert task is None, "Discrete policy does not support multitask."
# Categorical policy prior entropy_scale = scaled_log_prob / (log_prob + 1e-8)
logits = self._pi(z) info = TensorDict({
policy_dist = Categorical(logits=logits) "mean": mean,
"log_std": log_std,
action = policy_dist.sample() "action_prob": 1.,
action_probs = policy_dist.probs "entropy": -log_prob,
log_prob = F.log_softmax(logits, dim=-1) "scaled_entropy": -log_prob * entropy_scale,
})
one_hot_action = math.int_to_one_hot(action, self.cfg.action_dim) return action, info
return action, one_hot_action, log_prob, action_probs
def pi(self, z, task):
"""
Samples an action from the policy prior.
Policy can be either continuous (Gaussian) or discrete (categorical).
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
if self.cfg.action_space == 'discrete':
return self._discrete_pi(z, task)
elif self.cfg.action_space == 'continuous':
return self._continuous_pi(z, task)
else:
raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
def Q(self, z, a, task, return_type='min', target=False, detach=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """

View File

@@ -2,8 +2,10 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: discrete-cartpole-swingup task: dog-run
obs: state obs: state
episodic: false
num_envs: 1
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -13,8 +15,10 @@ eval_freq: 50000
# training # training
steps: 10_000_000 steps: 10_000_000
batch_size: 256 batch_size: 256
steps_per_update: 1
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
termination_coef: 1
consistency_coef: 20 consistency_coef: 20
rho: 0.5 rho: 0.5
lr: 3e-4 lr: 3e-4
@@ -62,13 +66,14 @@ dropout: 0.01
simnorm_dim: 8 simnorm_dim: 8
# logging # logging
wandb_project: ??? wandb_project: tdmpc3
wandb_entity: ??? wandb_entity: nicklashansen
wandb_silent: false wandb_silent: false
enable_wandb: true enable_wandb: true
save_csv: true save_csv: true
# misc # misc
compile: true
save_video: true save_video: true
save_agent: true save_agent: true
seed: 1 seed: 1
@@ -79,7 +84,6 @@ task_title: ???
multitask: ??? multitask: ???
tasks: ??? tasks: ???
obs_shape: ??? obs_shape: ???
action_space: ???
action_dim: ??? action_dim: ???
episode_length: ??? episode_length: ???
obs_shapes: ??? obs_shapes: ???
@@ -87,6 +91,3 @@ action_dims: ???
episode_lengths: ??? episode_lengths: ???
seed_steps: ??? seed_steps: ???
bin_size: ??? bin_size: ???
# speedups
compile: False

View File

@@ -1,12 +1,12 @@
from copy import deepcopy from copy import deepcopy
import warnings import warnings
import gym import gymnasium as gym
from envs.wrappers.discrete import DiscreteWrapper
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
from envs.wrappers.vectorized import Vectorized
def missing_dependencies(task): def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
@@ -27,6 +27,10 @@ try:
from envs.myosuite import make_env as make_myosuite_env from envs.myosuite import make_env as make_myosuite_env
except: except:
make_myosuite_env = missing_dependencies make_myosuite_env = missing_dependencies
try:
from envs.mujoco import make_env as make_mujoco_env
except:
make_mujoco_env = missing_dependencies
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -62,12 +66,7 @@ def make_env(cfg):
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
if cfg.task.startswith('discrete-'): for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
discrete = True
cfg.task = cfg.task.replace('discrete-', '')
else:
discrete = False
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break break
@@ -75,19 +74,15 @@ def make_env(cfg):
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
assert cfg.num_envs == 1 or cfg.get('obs', 'state') == 'state', \
'Vectorized environments only support state observations.'
env = Vectorized(cfg, fn)
env = TensorWrapper(env) env = TensorWrapper(env)
if discrete:
env = DiscreteWrapper(env)
if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box except: # Box
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
assert not isinstance(env.action_space, (gym.spaces.Dict, gym.spaces.MultiDiscrete)), \ cfg.action_dim = env.action_space.shape[0]
'Dict and MultiDiscrete action spaces are not supported.'
cfg.action_space = 'discrete' if isinstance(env.action_space, gym.spaces.Discrete) else 'continuous'
cfg.action_dim = env.action_space.n if cfg.action_space == 'discrete' else env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length) * cfg.num_envs
return env return env

View File

@@ -1,181 +1,102 @@
from collections import defaultdict from collections import defaultdict, deque
from typing import Any, NamedTuple
import dm_env import gymnasium as gym
import numpy as np import numpy as np
import torch
from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish from envs.tasks import cheetah, walker, hopper, reacher, ball_in_cup, pendulum, fish
from dm_control import suite from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale from dm_control.suite.wrappers import action_scale
from dm_env import StepType, specs
import gym from envs.wrappers.timeout import Timeout
class ExtendedTimeStep(NamedTuple): def get_obs_shape(env):
step_type: Any obs_shp = []
reward: Any for v in env.observation_spec().values():
discount: Any try:
observation: Any shp = np.prod(v.shape)
action: Any except:
shp = 1
def first(self): obs_shp.append(shp)
return self.step_type == StepType.FIRST return (int(np.sum(obs_shp)),)
def mid(self):
return self.step_type == StepType.MID
def last(self):
return self.step_type == StepType.LAST
class ActionRepeatWrapper(dm_env.Environment): class DMControlWrapper:
def __init__(self, env, num_repeats): def __init__(self, env, domain):
self._env = env
self._num_repeats = num_repeats
def step(self, action):
reward = 0.0
discount = 1.0
for i in range(self._num_repeats):
time_step = self._env.step(action)
reward += (time_step.reward or 0.0) * discount
discount *= time_step.discount
if time_step.last():
break
return time_step._replace(reward=reward, discount=discount)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ActionDTypeWrapper(dm_env.Environment):
def __init__(self, env, dtype):
self._env = env
wrapped_action_spec = env.action_spec()
self._action_spec = specs.BoundedArray(wrapped_action_spec.shape,
dtype,
wrapped_action_spec.minimum,
wrapped_action_spec.maximum,
'action')
def step(self, action):
action = action.astype(self._env.action_spec().dtype)
return self._env.step(action)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._action_spec
def reset(self):
return self._env.reset()
def __getattr__(self, name):
return getattr(self._env, name)
class ExtendedTimeStepWrapper(dm_env.Environment):
def __init__(self, env):
self._env = env
def reset(self):
time_step = self._env.reset()
return self._augment_time_step(time_step)
def step(self, action):
time_step = self._env.step(action)
return self._augment_time_step(time_step, action)
def _augment_time_step(self, time_step, action=None):
if action is None:
action_spec = self.action_spec()
action = np.zeros(action_spec.shape, dtype=action_spec.dtype)
return ExtendedTimeStep(observation=time_step.observation,
step_type=time_step.step_type,
action=action,
reward=time_step.reward or 0.0,
discount=time_step.discount or 1.0)
def observation_spec(self):
return self._env.observation_spec()
def action_spec(self):
return self._env.action_spec()
def __getattr__(self, name):
return getattr(self._env, name)
class TimeStepToGymWrapper:
def __init__(self, env, domain, task):
obs_shp = []
for v in env.observation_spec().values():
try:
shp = np.prod(v.shape)
except:
shp = 1
obs_shp.append(shp)
obs_shp = (int(np.sum(obs_shp)),)
act_shp = env.action_spec().shape
self.observation_space = gym.spaces.Box(
low=np.full(
obs_shp,
-np.inf,
dtype=np.float32),
high=np.full(
obs_shp,
np.inf,
dtype=np.float32),
dtype=np.float32,
)
self.action_space = gym.spaces.Box(
low=np.full(act_shp, env.action_spec().minimum),
high=np.full(act_shp, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.env = env self.env = env
self.domain = domain self.camera_id = 2 if domain == 'quadruped' else 0
self.task = task obs_shape = get_obs_shape(env)
self.max_episode_steps = 500 action_shape = env.action_spec().shape
self.t = 0 self.observation_space = gym.spaces.Box(
low=np.full(obs_shape, -np.inf, dtype=np.float32),
high=np.full(obs_shape, np.inf, dtype=np.float32),
dtype=np.float32)
self.action_space = gym.spaces.Box(
low=np.full(action_shape, env.action_spec().minimum),
high=np.full(action_shape, env.action_spec().maximum),
dtype=env.action_spec().dtype)
self.action_spec_dtype = env.action_spec().dtype
@property @property
def unwrapped(self): def unwrapped(self):
return self.env return self.env
@property
def reward_range(self):
return None
@property @property
def metadata(self): def metadata(self):
return None return None
def _obs_to_array(self, obs): def _obs_to_array(self, obs):
return np.concatenate([v.flatten() for v in obs.values()]) return torch.from_numpy(
np.concatenate([v.flatten() for v in obs.values()], dtype=np.float32))
def reset(self): def reset(self):
self.t = 0 return self._obs_to_array(self.env.reset().observation), defaultdict(float)
return self._obs_to_array(self.env.reset().observation)
def step(self, action): def step(self, action):
self.t += 1 reward = 0
time_step = self.env.step(action) action = action.astype(self.action_spec_dtype)
return self._obs_to_array(time_step.observation), time_step.reward, time_step.last() or self.t == self.max_episode_steps, defaultdict(float) for _ in range(2):
step = self.env.step(action)
reward += step.reward
return self._obs_to_array(step.observation), reward, False, defaultdict(float)
def render(self, mode='rgb_array', width=384, height=384, camera_id=0): def render(self, width=384, height=384, camera_id=None):
camera_id = dict(quadruped=2).get(self.domain, camera_id) return self.env.physics.render(height, width, camera_id or self.camera_id)
return self.env.physics.render(height, width, camera_id)
def close(self):
self.env.close()
class Pixels(gym.Wrapper):
def __init__(self, env, cfg, num_frames=3, size=64):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(num_frames*3, size, size), dtype=np.uint8)
self._frames = deque([], maxlen=num_frames)
self._size = size
def _get_obs(self, is_reset=False):
frame = self.env.render(width=self._size, height=self._size).transpose(2, 0, 1)
num_frames = self._frames.maxlen if is_reset else 1
for _ in range(num_frames):
self._frames.append(frame)
return torch.from_numpy(np.concatenate(self._frames))
def reset(self):
self.env.reset()
return self._get_obs(is_reset=True)
def step(self, action):
_, reward, done, info = self.env.step(action)
return self._get_obs(), reward, done, info
def close(self):
self.env.close()
def make_env(cfg): def make_env(cfg):
@@ -192,9 +113,9 @@ def make_env(cfg):
task, task,
task_kwargs={'random': cfg.seed}, task_kwargs={'random': cfg.seed},
visualize_reward=False) visualize_reward=False)
env = ActionDTypeWrapper(env, np.float32)
env = ActionRepeatWrapper(env, 2)
env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = action_scale.Wrapper(env, minimum=-1., maximum=1.)
env = ExtendedTimeStepWrapper(env) env = DMControlWrapper(env, domain)
env = TimeStepToGymWrapper(env, domain, task) if cfg.obs == 'rgb':
env = Pixels(env, cfg)
env = Timeout(env, max_episode_steps=500)
return env return env

View File

@@ -1,6 +1,6 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
import mani_skill2.envs import mani_skill2.envs
@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
def step(self, action): def step(self, action):
reward = 0 reward = 0
for _ in range(2): for _ in range(2):
obs, r, _, info = self.env.step(action) obs, r, done, info = self.env.step(action)
reward += r reward += r
return obs, reward, False, info info['terminated'] = done
if done:
break
return obs, reward, done, info
@property @property
def unwrapped(self): def unwrapped(self):
@@ -74,6 +77,6 @@ def make_env(cfg):
render_camera_cfgs=dict(width=384, height=384), render_camera_cfgs=dict(width=384, height=384),
) )
env = ManiSkillWrapper(env, cfg) env = ManiSkillWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import gym import gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
@@ -47,6 +47,6 @@ def make_env(cfg):
assert cfg.obs == 'state', 'This task only supports state observations.' assert cfg.obs == 'state', 'This task only supports state observations.'
env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed) env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id](seed=cfg.seed)
env = MetaWorldWrapper(env, cfg) env = MetaWorldWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

59
tdmpc2/envs/mujoco.py Normal file
View File

@@ -0,0 +1,59 @@
import numpy as np
import gymnasium as gym
from envs.wrappers.timeout import Timeout
MUJOCO_TASKS = {
'mujoco-walker': 'Walker2d-v4',
'mujoco-halfcheetah': 'HalfCheetah-v4',
'bipedal-walker': 'BipedalWalker-v3',
'lunarlander-continuous': 'LunarLander-v2',
}
class MuJoCoWrapper(gym.Wrapper):
def __init__(self, env, cfg):
super().__init__(env)
self.env = env
self.cfg = cfg
self._cumulative_reward = 0
def reset(self):
self._cumulative_reward = 0
return self.env.reset()[0]
def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action.copy())
self._cumulative_reward += reward
done = terminated or truncated
info['terminated'] = terminated
if self.cfg.task == 'lunarlander-continuous':
info['success'] = self._cumulative_reward > 200
return obs, reward, done, info
@property
def unwrapped(self):
return self.env.unwrapped
def render(self, **kwargs):
return self.env.render(**kwargs)
def make_env(cfg):
"""
Make classic/MuJoCo environment.
"""
if not cfg.task in MUJOCO_TASKS:
raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.'
if cfg.task == 'lunarlander-continuous':
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
else:
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
env = MuJoCoWrapper(env, cfg)
env = Timeout(env, max_episode_steps={
'lunarlander-continuous': 500,
'bipedal-walker': 1600,
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
return env

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
import gym import gymnasium as gym
from envs.wrappers.time_limit import TimeLimit from envs.wrappers.timeout import Timeout
MYOSUITE_TASKS = { MYOSUITE_TASKS = {
@@ -53,6 +53,6 @@ def make_env(cfg):
from myosuite.utils import gym as gym_utils from myosuite.utils import gym as gym_utils
env = gym_utils.make(MYOSUITE_TASKS[cfg.task]) env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg) env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = Timeout(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps
return env return env

View File

@@ -1,33 +0,0 @@
import gym
import torch
from common import math
class DiscreteWrapper(gym.Wrapper):
"""
Wrapper for converting continuous action spaces to discrete via binning.
"""
def __init__(self, env, bins_per_dim=5):
super().__init__(env)
self.bins_per_dim = bins_per_dim
self.continuous_dims = self.env.action_space.shape[0]
# Equally spaced bins along each dimension
self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims)
def rand_act(self):
action = torch.tensor(self.action_space.sample(), dtype=torch.int64)
return math.int_to_one_hot(action, self.action_space.n)
def _discrete_to_continuous(self, action):
# Convert a discrete action to a continuous action
action = torch.argmax(action)
action = action.item()
action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)]
action = torch.tensor(action, dtype=torch.float32)
return (action - 1) / 1
def step(self, action):
action = self._discrete_to_continuous(action)
return self.env.step(action)

View File

@@ -1,4 +1,4 @@
import gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch

View File

@@ -1,38 +0,0 @@
from collections import deque
import gym
import numpy as np
import torch
class PixelWrapper(gym.Wrapper):
"""
Wrapper for pixel observations. Compatible with DMControl environments.
"""
def __init__(self, cfg, env, num_frames=3, render_size=64):
super().__init__(env)
self.cfg = cfg
self.env = env
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(num_frames*3, render_size, render_size), dtype=np.uint8
)
self._frames = deque([], maxlen=num_frames)
self._render_size = render_size
def _get_obs(self):
frame = self.env.render(
mode='rgb_array', width=self._render_size, height=self._render_size
).transpose(2, 0, 1)
self._frames.append(frame)
return torch.from_numpy(np.concatenate(self._frames))
def reset(self):
self.env.reset()
for _ in range(self._frames.maxlen):
obs = self._get_obs()
return obs
def step(self, action):
_, reward, done, info = self.env.step(action)
return self._get_obs(), reward, done, info

View File

@@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
import gym import gymnasium as gym
import numpy as np import numpy as np
import torch import torch
@@ -12,14 +12,18 @@ class TensorWrapper(gym.Wrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env) super().__init__(env)
self._wrapped_vectorized = env.__class__.__name__ == 'Vectorized'
def rand_act(self): def rand_act(self):
if self._wrapped_vectorized:
return self.env.rand_act()
return torch.from_numpy(self.action_space.sample().astype(np.float32)) return torch.from_numpy(self.action_space.sample().astype(np.float32))
def _try_f32_tensor(self, x): def _try_f32_tensor(self, x):
x = torch.from_numpy(x) if isinstance(x, np.ndarray):
if x.dtype == torch.float64: x = torch.from_numpy(x)
x = x.float() if x.dtype == torch.float64:
x = x.float()
return x return x
def _obs_to_tensor(self, obs): def _obs_to_tensor(self, obs):
@@ -30,11 +34,24 @@ class TensorWrapper(gym.Wrapper):
obs = self._try_f32_tensor(obs) obs = self._try_f32_tensor(obs)
return obs return obs
def reset(self, task_idx=None): def reset(self, task_idx=None, **kwargs):
return self._obs_to_tensor(self.env.reset()) if self._wrapped_vectorized:
obs = self.env.reset(**kwargs)
else:
obs = self.env.reset()
return self._obs_to_tensor(obs)
def step(self, action): def step(self, action, **kwargs):
obs, reward, done, info = self.env.step(action.numpy()) if self._wrapped_vectorized:
info = defaultdict(float, info) obs, reward, terminated, truncated, info = self.env.step(action.numpy(), **kwargs)
info['success'] = float(info['success']) else:
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info obs, reward, terminated, truncated, info = self.env.step(action.numpy())
reward = torch.tensor(reward, dtype=torch.float32)
terminated = torch.tensor(terminated)
truncated = torch.tensor(truncated)
done = terminated | truncated
if 'success' not in info:
info['success'] = torch.zeros_like(reward)
info['terminated'] = terminated.float()
info['truncated'] = truncated.float()
return self._obs_to_tensor(obs), reward, done, info

View File

@@ -1,72 +0,0 @@
"""
Wrapper for limiting the time steps of an environment.
Source: https://github.com/openai/gym/blob/3498617bf031538a808b75b932f4ed2c11896a3e/gym/wrappers/time_limit.py
"""
from typing import Optional
import gym
class TimeLimit(gym.Wrapper):
"""This wrapper will issue a `done` signal if a maximum number of timesteps is exceeded.
Oftentimes, it is **very** important to distinguish `done` signals that were produced by the
:class:`TimeLimit` wrapper (truncations) and those that originate from the underlying environment (terminations).
This can be done by looking at the ``info`` that is returned when `done`-signal was issued.
The done-signal originates from the time limit (i.e. it signifies a *truncation*) if and only if
the key `"TimeLimit.truncated"` exists in ``info`` and the corresponding value is ``True``.
Example:
>>> from gym.envs.classic_control import CartPoleEnv
>>> from gym.wrappers import TimeLimit
>>> env = CartPoleEnv()
>>> env = TimeLimit(env, max_episode_steps=1000)
"""
def __init__(self, env: gym.Env, max_episode_steps: Optional[int] = None):
"""Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.
Args:
env: The environment to apply the wrapper
max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used)
"""
super().__init__(env)
if max_episode_steps is None and self.env.spec is not None:
max_episode_steps = env.spec.max_episode_steps
if self.env.spec is not None:
self.env.spec.max_episode_steps = max_episode_steps
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None
def step(self, action):
"""Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.
Args:
action: The environment step action
Returns:
The environment step ``(observation, reward, done, info)`` with "TimeLimit.truncated"=True
when truncated (the number of steps elapsed >= max episode steps) or
"TimeLimit.truncated"=False if the environment terminated
"""
observation, reward, done, info = self.env.step(action)
self._elapsed_steps += 1
if self._elapsed_steps >= self._max_episode_steps:
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
done = True
return observation, reward, done, info
def reset(self, **kwargs):
"""Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.
Args:
**kwargs: The kwargs to reset the environment with
Returns:
The reset environment
"""
self._elapsed_steps = 0
return self.env.reset(**kwargs)

View File

@@ -0,0 +1,27 @@
import gymnasium as gym
class Timeout(gym.Wrapper):
"""
Wrapper for enforcing a time limit on the environment.
"""
def __init__(self, env, max_episode_steps):
super().__init__(env)
self._max_episode_steps = max_episode_steps
@property
def max_episode_steps(self):
return self._max_episode_steps
def reset(self, **kwargs):
self._t = 0
return self.env.reset(**kwargs)
def step(self, action):
obs, reward, terminated, info = self.env.step(action)
self._t += 1
truncated = self._t >= self.max_episode_steps
info['terminated'] = terminated
info['truncated'] = truncated
return obs, reward, terminated, truncated, info

View File

@@ -0,0 +1,41 @@
from copy import deepcopy
from gymnasium.vector import AsyncVectorEnv
import numpy as np
import torch
class Vectorized():
"""
Vectorized environment for TD-MPC2 online training.
"""
def __init__(self, cfg, env_fn):
super().__init__()
self.cfg = cfg
def make():
_cfg = deepcopy(cfg)
_cfg.num_envs = 1
_cfg.seed = cfg.seed + np.random.randint(1000)
return env_fn(_cfg)
print(f'Creating {cfg.num_envs} environments...')
self.env = AsyncVectorEnv([make for _ in range(cfg.num_envs)])
env = make()
self.observation_space = env.observation_space
self.action_space = env.action_space
self.max_episode_steps = env.max_episode_steps
def rand_act(self):
return torch.rand((self.cfg.num_envs, *self.action_space.shape)) * 2 - 1
def reset(self):
obs, _ = self.env.reset()
return obs
def step(self, action):
return self.env.step(action)
def render(self, *args, **kwargs):
return self.env.render(*args, **kwargs)

View File

@@ -1,5 +1,5 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')

View File

@@ -4,6 +4,7 @@ import torch.nn.functional as F
from common import math from common import math
from common.scale import RunningScale from common.scale import RunningScale
from common.world_model import WorldModel from common.world_model import WorldModel
from common.layers import api_model_conversion
from tensordict import TensorDict from tensordict import TensorDict
@@ -23,6 +24,7 @@ class TDMPC2(torch.nn.Module):
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()}, {'params': self.model._reward.parameters()},
{'params': self.model._termination.parameters() if self.cfg.episodic else []},
{'params': self.model._Qs.parameters()}, {'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else [] {'params': self.model._task_emb.parameters() if self.cfg.multitask else []
} }
@@ -34,7 +36,9 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) print('Episode length:', cfg.episode_length)
print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
print('Compiling update function with torch.compile...') print('Compiling update function with torch.compile...')
self._update = torch.compile(self._update, mode="reduce-overhead") self._update = torch.compile(self._update, mode="reduce-overhead")
@@ -82,8 +86,14 @@ class TDMPC2(torch.nn.Module):
Args: Args:
fp (str or dict): Filepath or state dict to load. fp (str or dict): Filepath or state dict to load.
""" """
state_dict = fp if isinstance(fp, dict) else torch.load(fp) if isinstance(fp, dict):
self.model.load_state_dict(state_dict["model"]) state_dict = fp
else:
state_dict = torch.load(fp, map_location=torch.get_default_device(), weights_only=False)
state_dict = state_dict["model"] if "model" in state_dict else state_dict
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
self.model.load_state_dict(state_dict)
return
@torch.no_grad() @torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None): def act(self, obs, t0=False, eval_mode=False, task=None):
@@ -99,58 +109,32 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) return self.plan(obs, t0=t0, eval_mode=eval_mode, task=task).cpu()
else: z = self.model.encode(obs, task)
z = self.model.encode(obs, task) action, info = self.model.pi(z, task)
select_idx = int(not eval_mode or self.cfg.action_space == 'discrete') if eval_mode:
action = self.model.pi(z, task)[select_idx][0] action = info["mean"]
return action.cpu() return action.cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions.""" """Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1 G, discount = 0, 1
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[:, t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[:, t], task)
G = G + discount * reward G = G + discount * (1-termination) * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
pi = self.model.pi(z, task)[1] if self.cfg.episodic:
if self.cfg.action_space == 'discrete': termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
pi = pi.squeeze(1) # TODO: this is a bit hacky action, _ = self.model.pi(z, task)
return G + discount * self.model.Q(z, pi, task, return_type='avg') return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
@torch.no_grad()
def _sample_policy(self, z, task):
"""Sample trajectories from the policy prior."""
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
for t in range(self.cfg.horizon-1):
action = self.model.pi(z, task)[1]
if self.cfg.action_space == 'discrete':
action = action.squeeze(1)
pi_actions[t] = action
z = self.model.next(z, pi_actions[t], task)
action = self.model.pi(z, task)[1]
if self.cfg.action_space == 'discrete':
action = action.squeeze(1)
pi_actions[-1] = action
return pi_actions
@torch.no_grad()
def _sample_actions(self, n, mean=None, std=None):
"""Sample actions from a Gaussian or Categorical distribution."""
if self.cfg.action_space == 'discrete':
actions = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, n), device=self.device)
actions = math.int_to_one_hot(actions, self.cfg.action_dim)
else:
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions = (mean.unsqueeze(1) + std.unsqueeze(1) * r).clamp(-1, 1)
return actions
@torch.no_grad() @torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
@@ -158,7 +142,7 @@ class TDMPC2(torch.nn.Module):
Plan a sequence of actions using the learned world model. Plan a sequence of actions using the learned world model.
Args: Args:
z (torch.Tensor): Latent state from which to plan. obs (torch.Tensor): Observation from which to plan.
t0 (bool): Whether this is the first observation in the episode. t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution. eval_mode (bool): Whether to use the mean of the action distribution.
task (Torch.Tensor): Task index (only used for multi-task experiments). task (Torch.Tensor): Task index (only used for multi-task experiments).
@@ -166,61 +150,71 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
# Encode observation
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
z = z.repeat(self.cfg.num_samples, 1)
# Initialize parameters
if self.cfg.action_space == 'continuous':
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0:
mean[:-1] = self._prev_mean[1:]
else:
mean, std = None, None
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
# Sample policy trajectories # Sample policy trajectories
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = self._sample_policy(z[:self.cfg.num_pi_trajs], task) pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.unsqueeze(1).repeat(1, self.cfg.num_pi_trajs, 1).view(self.cfg.num_envs * self.cfg.num_pi_trajs, -1)
for t in range(self.cfg.horizon - 1):
a, _ = self.model.pi(_z, task)
pi_actions[:, t] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
_z = self.model.next(_z, a, task)
a, _ = self.model.pi(_z, task)
pi_actions[:, -1] = a.view(self.cfg.num_envs, self.cfg.num_pi_trajs, self.cfg.action_dim)
# Initialize state and parameters
z = z.unsqueeze(1).repeat(1, self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, device=self.device)
if not t0:
mean[:, :-1] = self._prev_mean[:, 1:]
actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0:
actions[:, :, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample random actions # Sample new actions
actions[:, self.cfg.num_pi_trajs:] = self._sample_actions(self.cfg.num_samples-self.cfg.num_pi_trajs, mean, std) r = torch.randn(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples - self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
actions_sample = mean.unsqueeze(2) + std.unsqueeze(2) * r
actions[:, :, self.cfg.num_pi_trajs:] = actions_sample.clamp(-1, 1)
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Select elites and compute scores # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
max_value = elite_value.max(0).values elite_actions = actions.gather(
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) dim=2,
score = score / score.sum(0) index=elite_idxs[:, None, :, None].expand(-1, self.cfg.horizon, self.cfg.num_elites, self.cfg.action_dim)
)
# Update parameters # Update parameters
if self.cfg.action_space == 'continuous': score = torch.exp(self.cfg.temperature * (elite_value - elite_value.max(1, keepdim=True).values))
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) score = score / (score.sum(dim=1, keepdim=True) + 1e-9)
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() score_exp = score.unsqueeze(1)
std = std.clamp(self.cfg.min_std, self.cfg.max_std) mean = (score_exp * elite_actions).sum(dim=2) / (score_exp.sum(dim=2) + 1e-9)
if self.cfg.multitask: std = ((score_exp * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) /
mean = mean * self.model._action_masks[task] (score_exp.sum(dim=2) + 1e-9)).sqrt().clamp(self.cfg.min_std, self.cfg.max_std)
std = std * self.model._action_masks[task] if self.cfg.multitask:
else: mean = mean * self.model._action_masks[task]
break std = std * self.model._action_masks[task]
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs logits = torch.log(score.squeeze(2) + 1e-9)
action = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)[0] rand_idx = math.gumbel_softmax_sample(logits, temperature=self.cfg.temperature, dim=1)
if self.cfg.action_space == 'continuous': selected_actions = elite_actions.gather(
if not eval_mode: dim=2,
action = action + std[0] * torch.randn(self.cfg.action_dim, device=std.device) index=rand_idx[:, None, None, None].expand(-1, self.cfg.horizon, 1, self.cfg.action_dim)
self._prev_mean.copy_(mean) ).squeeze(2)
action = action.clamp(-1, 1) action, std_out = selected_actions[:, 0], std[:, 0]
if not eval_mode:
return action action = action + std_out * torch.randn_like(action)
self._prev_mean.copy_(mean)
return action.clamp(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """
@@ -233,59 +227,51 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
_, actions, log_probs, action_probs = self.model.pi(zs, task) action, info = self.model.pi(zs, task)
qs = self.model.Q(zs, action, task, return_type='avg', detach=True)
if self.cfg.action_space == 'discrete': self.scale.update(qs[0])
actions = torch.eye(self.cfg.action_dim, device=zs.device).unsqueeze(0)
zs = zs.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
actions = actions.unsqueeze(0).repeat(zs.shape[0], zs.shape[1], 1, 1)
qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
if self.cfg.action_space == 'discrete':
qs = qs.squeeze(-1)
self.scale.update(torch.sum(action_probs*qs,dim=(1,2),keepdim=True)[0])
else:
self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
# Loss is a weighted sum of Q-values # Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
if self.cfg.action_space == 'discrete': pi_loss = (-(self.cfg.entropy_coef * info["scaled_entropy"] + qs).mean(dim=(1,2)) * rho).mean()
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
else:
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
self.pi_optim.zero_grad(set_to_none=True) self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.detach(), pi_grad_norm info = TensorDict({
"pi_loss": pi_loss,
"pi_grad_norm": pi_grad_norm,
"pi_entropy": info["entropy"],
"pi_scaled_entropy": info["scaled_entropy"],
"pi_scale": self.scale.value,
})
return info
@torch.no_grad() @torch.no_grad()
def _td_target(self, next_z, reward, task): def _td_target(self, next_z, reward, terminated, task):
""" """
Compute the TD-target from a reward and the observation at the following time step. Compute the TD-target from a reward and the observation at the following time step.
Args: Args:
next_z (torch.Tensor): Latent state at the following time step. next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step. reward (torch.Tensor): Reward at the current time step.
terminated (torch.Tensor): Termination signal at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments). task (torch.Tensor): Task index (only used for multi-task experiments).
Returns: Returns:
torch.Tensor: TD-target. torch.Tensor: TD-target.
""" """
pi = self.model.pi(next_z, task)[1] action, _ = self.model.pi(next_z, task)
if self.cfg.action_space == 'discrete':
pi = pi.squeeze(2) # TODO: this is a bit hacky
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, task=None): def _update(self, obs, action, reward, terminated, task=None):
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
next_z = self.model.encode(obs[1:], task) next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task) td_targets = self._td_target(next_z, reward, terminated, task)
# Prepare for update # Prepare for update
self.model.train() self.model.train()
@@ -304,6 +290,8 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
if self.cfg.episodic:
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
# Compute losses # Compute losses
reward_loss, value_loss = 0, 0 reward_loss, value_loss = 0, 0
@@ -314,10 +302,15 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon
if self.cfg.episodic:
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
else:
termination_loss = 0.
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss + self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss + self.cfg.reward_coef * reward_loss +
self.cfg.termination_coef * termination_loss +
self.cfg.value_coef * value_loss self.cfg.value_coef * value_loss
) )
@@ -328,23 +321,25 @@ class TDMPC2(torch.nn.Module):
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
# Update policy # Update policy
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) pi_info = self.update_pi(zs.detach(), task)
# Update target Q-functions # Update target Q-functions
self.model.soft_update_target_Q() self.model.soft_update_target_Q()
# Return training statistics # Return training statistics
self.model.eval() self.model.eval()
return TensorDict({ info = TensorDict({
"consistency_loss": consistency_loss, "consistency_loss": consistency_loss,
"reward_loss": reward_loss, "reward_loss": reward_loss,
"value_loss": value_loss, "value_loss": value_loss,
"pi_loss": pi_loss, "termination_loss": termination_loss,
"total_loss": total_loss, "total_loss": total_loss,
"grad_norm": grad_norm, "grad_norm": grad_norm,
"pi_grad_norm": pi_grad_norm, })
"pi_scale": self.scale.value, if self.cfg.episodic:
}).detach().mean() info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
info.update(pi_info)
return info.detach().mean()
def update(self, buffer): def update(self, buffer):
""" """
@@ -356,9 +351,9 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
obs, action, reward, task = buffer.sample() obs, action, reward, terminated, task = buffer.sample()
kwargs = {} kwargs = {}
if task is not None: if task is not None:
kwargs["task"] = task kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin() torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs) return self._update(obs, action, reward, terminated, **kwargs)

View File

@@ -1,5 +1,5 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
os.environ['LAZY_LEGACY_OP'] = '0' os.environ['LAZY_LEGACY_OP'] = '0'
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1" os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
os.environ['TORCH_LOGS'] = "+recompiles" os.environ['TORCH_LOGS'] = "+recompiles"

View File

@@ -39,18 +39,14 @@ class OfflineTrainer(Trainer):
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),}) f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
return results return results
def train(self): def _load_dataset(self):
"""Train a TD-MPC2 agent.""" """Load dataset for offline training."""
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
'Offline training only supports multitask training with mt30 or mt80 task sets.'
# Load data
fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp))) fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}' assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}') print(f'Found {len(fps)} files in {fp}')
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \ if len(fps) < (20 if self.cfg.task == 'mt80' else 4):
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.' print(f'WARNING: expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.')
# Create buffer for sampling # Create buffer for sampling
_cfg = deepcopy(self.cfg) _cfg = deepcopy(self.cfg)
@@ -59,15 +55,20 @@ class OfflineTrainer(Trainer):
_cfg.steps = _cfg.buffer_size _cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg) self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'): for fp in tqdm(fps, desc='Loading data'):
td = torch.load(fp) td = torch.load(fp, weights_only=False)
assert td.shape[1] == _cfg.episode_length, \ assert td.shape[1] == _cfg.episode_length, \
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \ f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
f'please double-check your config.' f'please double-check your config.'
for i in range(len(td)): self.buffer.load(td)
self.buffer.add(td[i])
expected_episodes = _cfg.buffer_size // _cfg.episode_length expected_episodes = _cfg.buffer_size // _cfg.episode_length
assert self.buffer.num_eps == expected_episodes, \ if self.buffer.num_eps != expected_episodes:
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.' print(f'WARNING: buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes for {self.cfg.task} task set.')
def train(self):
"""Train a TD-MPC2 agent."""
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
'Offline training only supports multitask training with mt30 or mt80 task sets.'
self._load_dataset()
print(f'Training agent for {self.cfg.steps} iterations...') print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {} metrics = {}
@@ -80,7 +81,7 @@ class OfflineTrainer(Trainer):
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
metrics = { metrics = {
'iteration': i, 'iteration': i,
'total_time': time() - self._start_time, 'elapsed_time': time() - self._start_time,
} }
metrics.update(train_metrics) metrics.update(train_metrics)
if i % self.cfg.eval_freq == 0: if i % self.cfg.eval_freq == 0:

View File

@@ -1,6 +1,5 @@
from time import time from time import time
import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from trainer.base import Trainer from trainer.base import Trainer
@@ -17,20 +16,22 @@ class OnlineTrainer(Trainer):
def common_metrics(self): def common_metrics(self):
"""Return a dictionary of current metrics.""" """Return a dictionary of current metrics."""
elapsed_time = time() - self._start_time
return dict( return dict(
step=self._step, step=self._step,
episode=self._ep_idx, episode=self._ep_idx,
total_time=time() - self._start_time, elapsed_time=elapsed_time,
steps_per_second=self._step / elapsed_time
) )
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], [] ep_rewards, ep_successes, ep_lengths = [], [], []
for i in range(self.cfg.eval_episodes): for i in range(self.cfg.eval_episodes // self.cfg.num_envs):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0 obs, done, ep_reward, t = self.env.reset(), torch.tensor(False), 0, 0
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0)) self.logger.video.init(self.env, enabled=(i==0))
while not done: while not done.any():
torch.compiler.cudagraph_mark_step_begin() torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True) action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
@@ -38,43 +39,49 @@ class OnlineTrainer(Trainer):
t += 1 t += 1
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.record(self.env) self.logger.video.record(self.env)
assert done.all(), 'Vectorized environments must reset all environments at once.'
ep_rewards.append(ep_reward) ep_rewards.append(ep_reward)
ep_successes.append(info['success']) ep_successes.append(info['success'])
ep_lengths.append(t)
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.save(self._step) self.logger.video.save(self._step)
return dict( return dict(
episode_reward=np.nanmean(ep_rewards), episode_reward=torch.cat(ep_rewards).mean(),
episode_success=np.nanmean(ep_successes), episode_success=info['success'].mean(),
episode_length= torch.tensor(ep_lengths, dtype=torch.float32).mean(),
) )
def to_td(self, obs, action=None, reward=None): def to_td(self, obs, action=None, reward=None, terminated=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu') obs = TensorDict(obs, batch_size=(), device='cpu')
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action_val = -1 if self.cfg.action_space == 'discrete' else float('nan') action = torch.full_like(self.env.rand_act(), float('nan'))
action = torch.full_like(self.env.rand_act(), action_val)
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
if terminated is None:
terminated = torch.tensor(float('nan')).repeat(self.cfg.num_envs)
td = TensorDict( td = TensorDict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
batch_size=(1,)) terminated=terminated.unsqueeze(0),
batch_size=(1, self.cfg.num_envs,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, False train_metrics, done, eval_next = {}, torch.tensor(True), True
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0: if self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True
# Reset environment # Reset environment
if done: if done.any():
assert done.all(), 'Vectorized environments must reset all environments at once.'
if eval_next: if eval_next:
eval_metrics = self.eval() eval_metrics = self.eval()
eval_metrics.update(self.common_metrics()) eval_metrics.update(self.common_metrics())
@@ -82,13 +89,19 @@ class OnlineTrainer(Trainer):
eval_next = False eval_next = False
if self._step > 0: if self._step > 0:
if info['terminated'].any() and not self.cfg.episodic:
raise ValueError('Termination detected but you are not in episodic mode. ' \
'Set `episodic=true` to enable support for terminations.')
tds = torch.cat(self._tds)
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_reward=tds['reward'].nansum(0).mean(),
episode_success=info['success'], episode_success=info['success'].nanmean(),
episode_length=len(self._tds),
episode_terminated=info['terminated'].nanmean(),
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(tds)
obs = self.env.reset() obs = self.env.reset()
self._tds = [self.to_td(obs)] self._tds = [self.to_td(obs)]
@@ -98,25 +111,22 @@ class OnlineTrainer(Trainer):
action = self.agent.act(obs, t0=len(self._tds)==1) action = self.agent.act(obs, t0=len(self._tds)==1)
else: else:
action = self.env.rand_act() action = self.env.rand_act()
if self.cfg.action_space == 'discrete':
# exploration schedule
# minimum 0.01, maximum 0.05, anneal over 20k steps
if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward)) self._tds.append(self.to_td(obs, action, reward, info['terminated']))
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps: if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps num_updates = int(self.cfg.seed_steps / self.cfg.steps_per_update)
print('Pretraining agent on seed data...') print('Pretraining agent on seed data...')
else: else:
num_updates = 1 num_updates = max(1, int(self.cfg.num_envs / self.cfg.steps_per_update))
for _ in range(num_updates): for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer) _train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics) train_metrics.update(_train_metrics)
if self._step == self.cfg.seed_steps:
print('Pretraining complete.')
self._step += 1 self._step += self.cfg.num_envs
self.logger.finish(self.agent) self.logger.finish(self.agent)