Merge pull request #49 from nicklashansen/speedups

Speedups
This commit is contained in:
Nicklas Hansen
2024-11-10 12:33:32 -08:00
committed by GitHub
18 changed files with 327 additions and 251 deletions

View File

@@ -12,6 +12,13 @@ Official implementation of
---- ----
**Announcement: training just got ~4.5x faster!**
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
----
## Overview ## Overview
TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*). TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*).

View File

@@ -0,0 +1 @@
for i in {0..3}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done

View File

@@ -0,0 +1 @@
for i in {0..19}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done

View File

@@ -5,51 +5,42 @@ channels:
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
- cudatoolkit=11.7 - glew=2.2.0
- glew=2.1.0 - glib=2.78.4
- glib=2.68.4 - pip=24.0
- pip=21.0 - python=3.9
- python=3.9.0 - pytorch
- pytorch>=2.2.2 - pytorch-cuda=12.4
- torchvision>=0.16.2 - torchvision
- pip: - pip:
- absl-py==2.0.0 - absl-py==2.1.0
- "cython<3" - "cython<3"
- dm-control==1.0.8 - dm-control==1.0.8
- glfw==2.7.0
- ffmpeg==1.4 - ffmpeg==1.4
- glfw==2.6.4 - imageio==2.34.1
- imageio-ffmpeg==0.4.9
- h5py==3.11.0
- hydra-core==1.3.2 - hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0 - hydra-submitit-launcher==1.2.0
- imageio==2.33.1
- imageio-ffmpeg==0.4.9
- kornia==0.7.1
- moviepy==1.0.3
- mujoco==2.3.1
- mujoco-py==2.1.2.14
- numpy==1.23.5
- omegaconf==2.3.0
- open3d==0.18.0
- opencv-contrib-python==4.9.0.80
- opencv-python==4.9.0.80
- pandas==2.1.4
- sapien==2.2.1
- submitit==1.5.1 - submitit==1.5.1
- setuptools==65.5.0 - setuptools==65.5.0
- patchelf==0.17.2.1 - patchelf==0.17.2.1
- protobuf==4.25.2 - omegaconf==2.3.0
- pillow==10.2.0 - moviepy==1.0.3
- pyquaternion==0.9.9 - mujoco==2.3.1
- tensordict-nightly==2024.3.26 - numpy==1.24.4
- tensordict-nightly
- torchrl-nightly
- kornia==0.7.2
- termcolor==2.4.0 - termcolor==2.4.0
- torchrl-nightly==2024.3.26 - tqdm==4.66.4
- transforms3d==0.4.1 - pandas==2.0.3
- trimesh==4.0.9 - wandb==0.17.4
- tqdm==4.66.1
- wandb==0.16.2
- wheel==0.38.0 - wheel==0.38.0
#################### ####################
# Gym: # Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite) # (unmaintained but required for maniskill2/meta-world)
# - gym==0.21.0 # - gym==0.21.0
#################### ####################
# ManiSkill2: # ManiSkill2:
@@ -61,6 +52,5 @@ dependencies:
# - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb # - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb
#################### ####################
# MyoSuite: # MyoSuite:
# (requires gym==0.13 which conflicts with meta-world / mani-skill2)
# - myosuite # - myosuite
#################### ####################

View File

@@ -12,7 +12,7 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device('cuda:0')
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler( self._sampler = SliceSampler(
num_slices=self.cfg.batch_size, num_slices=self.cfg.batch_size,
@@ -41,8 +41,8 @@ class Buffer():
return ReplayBuffer( return ReplayBuffer(
storage=storage, storage=storage,
sampler=self._sampler, sampler=self._sampler,
pin_memory=True, pin_memory=False,
prefetch=1, prefetch=0,
batch_size=self._batch_size, batch_size=self._batch_size,
) )
@@ -58,32 +58,30 @@ class Buffer():
total_bytes = bytes_per_step*self._capacity total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.') print(f'Using {storage_device.upper()} memory for storage.')
self._storage_device = torch.device(storage_device)
return self._reserve_buffer( return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device)) LazyTensorStorage(self._capacity, device=self._storage_device)
) )
def _to_device(self, *args, device=None):
if device is None:
device = self._device
return (arg.to(device, non_blocking=True) \
if arg is not None else None for arg in args)
def _prepare_batch(self, td): def _prepare_batch(self, td):
""" """
Prepare a sampled batch for training (post-processing). Prepare a sampled batch for training (post-processing).
Expects `td` to be a TensorDict with batch size TxB. Expects `td` to be a TensorDict with batch size TxB.
""" """
obs = td['obs'] td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True)
action = td['action'][1:] obs = td.get('obs').contiguous()
reward = td['reward'][1:].unsqueeze(-1) action = td.get('action')[1:].contiguous()
task = td['task'][0] if 'task' in td.keys() else None reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
return self._to_device(obs, action, reward, task) task = td.get('task', None)
if task is not None:
task = task[0].contiguous()
return obs, action, reward, task
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td)
self._buffer.extend(td) self._buffer.extend(td)

View File

@@ -1,8 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from functorch import combine_state_for_ensemble from tensordict import from_modules
from copy import deepcopy
class Ensemble(nn.Module): class Ensemble(nn.Module):
""" """
@@ -11,14 +11,18 @@ class Ensemble(nn.Module):
def __init__(self, modules, **kwargs): def __init__(self, modules, **kwargs):
super().__init__() super().__init__()
modules = nn.ModuleList(modules) # combine_state_for_ensemble causes graph breaks
fn, params, _ = combine_state_for_ensemble(modules) self.params = from_modules(*modules, as_module=True)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs) with self.params[0].data.to("meta").to_module(modules[0]):
self.params = nn.ParameterList([nn.Parameter(p) for p in params]) self.module = deepcopy(modules[0])
self._repr = str(modules) self._repr = str(modules)
def _call(self, params, *args, **kwargs):
with params.to_module(self.module):
return self.module(*args, **kwargs)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs) return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
def __repr__(self): def __repr__(self):
return 'Vectorized ' + self._repr return 'Vectorized ' + self._repr
@@ -32,13 +36,13 @@ class ShiftAug(nn.Module):
def __init__(self, pad=3): def __init__(self, pad=3):
super().__init__() super().__init__()
self.pad = pad self.pad = pad
self.padding = tuple([self.pad] * 4)
def forward(self, x): def forward(self, x):
x = x.float() x = x.float()
n, _, h, w = x.size() n, _, h, w = x.size()
assert h == w assert h == w
padding = tuple([self.pad] * 4) x = F.pad(x, self.padding, 'replicate')
x = F.pad(x, padding, 'replicate')
eps = 1.0 / (h + 2 * self.pad) eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h] arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
@@ -59,7 +63,7 @@ class PixelPreprocess(nn.Module):
super().__init__() super().__init__()
def forward(self, x): def forward(self, x):
return x.div_(255.).sub_(0.5) return x.div(255.).sub(0.5)
class SimNorm(nn.Module): class SimNorm(nn.Module):
@@ -87,11 +91,13 @@ class NormedLinear(nn.Linear):
Linear layer with LayerNorm, activation, and optionally dropout. Linear layer with LayerNorm, activation, and optionally dropout.
""" """
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs): def __init__(self, *args, dropout=0., act=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features) self.ln = nn.LayerNorm(self.out_features)
if act is None:
act = nn.Mish(inplace=False)
self.act = act self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None
def forward(self, x): def forward(self, x):
x = super().forward(x) x = super().forward(x)
@@ -130,9 +136,9 @@ def conv(in_shape, num_channels, act=None):
assert in_shape[-1] == 64 # assumes rgb observations to be 64x64 assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
layers = [ layers = [
ShiftAug(), PixelPreprocess(), ShiftAug(), PixelPreprocess(),
nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True), nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True), nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True), nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()] nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()]
if act: if act:
layers.append(act) layers.append(act)

View File

@@ -1,10 +1,11 @@
import dataclasses
import os import os
import datetime import datetime
import re import re
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from termcolor import colored from termcolor import colored
from omegaconf import OmegaConf
from common import TASK_SET from common import TASK_SET
@@ -116,7 +117,7 @@ class Logger:
print_run(cfg) print_run(cfg)
self.project = cfg.get("wandb_project", "none") self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none") self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none": if not cfg.enable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"])) print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False cfg.save_agent = False
cfg.save_video = False cfg.save_video = False
@@ -133,7 +134,7 @@ class Logger:
group=self._group, group=self._group,
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
dir=self._log_dir, dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True), config=dataclasses.asdict(cfg),
) )
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb self._wandb = wandb

View File

@@ -9,30 +9,27 @@ def soft_ce(pred, target, cfg):
return -(target * pred).sum(-1, keepdim=True) return -(target * pred).sum(-1, keepdim=True)
@torch.jit.script
def log_std(x, low, dif): def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1) return low + 0.5 * dif * (torch.tanh(x) + 1)
@torch.jit.script
def _gaussian_residual(eps, log_std): def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std return -0.5 * eps.pow(2) - log_std
@torch.jit.script
def _gaussian_logprob(residual): def _gaussian_logprob(residual):
return residual - 0.5 * torch.log(2 * torch.pi) log2pi = 1.8378770351409912
return residual - 0.5 * log2pi
def gaussian_logprob(eps, log_std, size=None): def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability.""" """Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None: if size is None:
size = eps.size(-1) size = eps.shape[-1]
return _gaussian_logprob(residual) * size return _gaussian_logprob(residual) * size
@torch.jit.script
def _squash(pi): def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
@@ -45,7 +42,6 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi return mu, pi, log_pi
@torch.jit.script
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.
@@ -54,7 +50,6 @@ def symlog(x):
return torch.sign(x) * torch.log(1 + torch.abs(x)) return torch.sign(x) * torch.log(1 + torch.abs(x))
@torch.jit.script
def symexp(x): def symexp(x):
""" """
Symmetric exponential function. Symmetric exponential function.
@@ -70,26 +65,33 @@ def two_hot(x, cfg):
elif cfg.num_bins == 1: elif cfg.num_bins == 1:
return symlog(x) return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1) x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size)
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx).unsqueeze(-1)
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) soft_two_hot = torch.zeros(x.shape[0], cfg.num_bins, device=x.device, dtype=x.dtype)
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset) bin_idx = bin_idx.long()
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) soft_two_hot = soft_two_hot.scatter(1, bin_idx.unsqueeze(1), 1 - bin_offset)
soft_two_hot = soft_two_hot.scatter(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot return soft_two_hot
DREG_BINS = None
def two_hot_inv(x, cfg): def two_hot_inv(x, cfg):
"""Converts a batch of soft two-hot encoded vectors to scalars.""" """Converts a batch of soft two-hot encoded vectors to scalars."""
global DREG_BINS
if cfg.num_bins == 0: if cfg.num_bins == 0:
return x return x
elif cfg.num_bins == 1: elif cfg.num_bins == 1:
return symexp(x) return symexp(x)
if DREG_BINS is None: dreg_bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device, dtype=x.dtype)
DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
x = F.softmax(x, dim=-1) x = F.softmax(x, dim=-1)
x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True) x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
return symexp(x) return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
logits = p.log()
# Generate Gumbel noise
gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
) # ~Gumbel(0,1)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1)

View File

@@ -1,5 +1,7 @@
import dataclasses
import re import re
from pathlib import Path from pathlib import Path
from typing import Any
import hydra import hydra
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -7,6 +9,23 @@ from omegaconf import OmegaConf
from common import MODEL_SIZE, TASK_SET from common import MODEL_SIZE, TASK_SET
def cfg_to_dataclass(cfg, frozen=False):
"""
Converts an OmegaConf config to a dataclass object.
This prevents graph breaks when used with torch.compile.
"""
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
def parse_cfg(cfg: OmegaConf) -> OmegaConf: def parse_cfg(cfg: OmegaConf) -> OmegaConf:
""" """
Parses a Hydra config. Mostly for convenience. Parses a Hydra config. Mostly for convenience.
@@ -53,9 +72,14 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
if cfg.multitask: if cfg.multitask:
cfg.task_title = cfg.task.upper() cfg.task_title = cfg.task.upper()
# Account for slight inconsistency in task_dim for the mt30 experiments # Account for slight inconsistency in task_dim for the mt30 experiments
cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64 cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64
else: else:
cfg.task_dim = 0 cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
return cfg # Check torch.compile compatibility
if cfg.get('compile', False):
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.'
return cfg_to_dataclass(cfg)

View File

@@ -1,48 +1,49 @@
import torch import torch
from torch.nn import Buffer
class RunningScale(torch.nn.Module):
class RunningScale:
"""Running trimmed scale estimator.""" """Running trimmed scale estimator."""
def __init__(self, cfg): def __init__(self, cfg):
super().__init__()
self.cfg = cfg self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
def state_dict(self): def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles) return dict(value=self.value, percentiles=self._percentiles)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self._value.data.copy_(state_dict['value']) self.value.copy_(state_dict['value'])
self._percentiles.data.copy_(state_dict['percentiles']) self._percentiles.copy_(state_dict['percentiles'])
@property def _positions(self, x_shape):
def value(self): positions = self._percentiles * (x_shape-1) / 100
return self._value.cpu().item() floored = torch.floor(positions)
ceiled = floored + 1
ceiled = torch.where(ceiled > x_shape - 1, x_shape - 1, ceiled)
weight_ceiled = positions-floored
weight_floored = 1.0 - weight_ceiled
return floored.long(), ceiled.long(), weight_floored.unsqueeze(1), weight_ceiled.unsqueeze(1)
def _percentile(self, x): def _percentile(self, x):
x_dtype, x_shape = x.dtype, x.shape x_dtype, x_shape = x.dtype, x.shape
x = x.view(x.shape[0], -1) x = x.flatten(1, x.ndim-1)
in_sorted, _ = torch.sort(x, dim=0) in_sorted = torch.sort(x, dim=0).values
positions = self._percentiles * (x.shape[0]-1) / 100 floored, ceiled, weight_floored, weight_ceiled = self._positions(x.shape[0])
floored = torch.floor(positions) d0 = in_sorted[floored] * weight_floored
ceiled = floored + 1 d1 = in_sorted[ceiled] * weight_ceiled
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1 return (d0+d1).reshape(-1, *x_shape[1:]).to(x_dtype)
weight_ceiled = positions-floored
weight_floored = 1.0 - weight_ceiled
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype)
def update(self, x): def update(self, x):
percentiles = self._percentile(x.detach()) percentiles = self._percentile(x.detach())
value = torch.clamp(percentiles[1] - percentiles[0], min=1.) value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
self._value.data.lerp_(value, self.cfg.tau) self.value.data.lerp_(value, self.cfg.tau)
def __call__(self, x, update=False): def forward(self, x, update=False):
if update: if update:
self.update(x) self.update(x)
return x * (1/self.value) return x / self.value
def __repr__(self): def __repr__(self):
return f'RunningScale(S: {self.value})' return f'RunningScale(S: {self.value})'

View File

@@ -1,11 +1,10 @@
from copy import deepcopy from copy import deepcopy
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from common import layers, math, init from common import layers, math, init
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
@@ -18,7 +17,7 @@ class WorldModel(nn.Module):
self.cfg = cfg self.cfg = cfg
if cfg.multitask: if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) self.register_buffer("_action_masks", torch.zeros(len(cfg.tasks), cfg.action_dim))
for i in range(len(cfg.tasks)): for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1. self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
@@ -27,24 +26,41 @@ class WorldModel(nn.Module):
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min) self.register_buffer("log_std_min", torch.tensor(cfg.log_std_min))
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min self.register_buffer("log_std_dif", torch.tensor(cfg.log_std_max) - self.log_std_min)
self.init()
def init(self):
# Create params
self._detach_Qs_params = TensorDictParams(self._Qs.params.data, no_convert=True)
self._target_Qs_params = TensorDictParams(self._Qs.params.data.clone(), no_convert=True)
# Create modules
with self._detach_Qs_params.data.to("meta").to_module(self._Qs.module):
self._detach_Qs = deepcopy(self._Qs)
self._target_Qs = deepcopy(self._Qs)
# Assign params to modules
self._detach_Qs.params = self._detach_Qs_params
self._target_Qs.params = self._target_Qs_params
def __repr__(self):
repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params)
return repr
@property @property
def total_params(self): def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters() if p.requires_grad)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
"""
Overriding `to` method to also move additional tensors to device.
"""
super().to(*args, **kwargs) super().to(*args, **kwargs)
if self.cfg.multitask: self.init()
self._action_masks = self._action_masks.to(*args, **kwargs)
self.log_std_min = self.log_std_min.to(*args, **kwargs)
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
return self return self
def train(self, mode=True): def train(self, mode=True):
@@ -55,25 +71,11 @@ class WorldModel(nn.Module):
self._target_Qs.train(False) self._target_Qs.train(False)
return self return self
def track_q_grad(self, mode=True):
"""
Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings.
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.cfg.multitask:
for p in self._task_emb.parameters():
p.requires_grad_(mode)
def soft_update_target_Q(self): def soft_update_target_Q(self):
""" """
Soft-update target Q-networks using Polyak averaging. Soft-update target Q-networks using Polyak averaging.
""" """
with torch.no_grad(): self._target_Qs_params.lerp_(self._detach_Qs_params, self.cfg.tau)
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
p_target.data.lerp_(p.data, self.cfg.tau)
def task_emb(self, x, task): def task_emb(self, x, task):
""" """
@@ -147,7 +149,7 @@ class WorldModel(nn.Module):
return mu, pi, log_pi, log_std return mu, pi, log_pi, log_std
def Q(self, z, a, task, return_type='min', target=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """
Predict state-action value. Predict state-action value.
`return_type` can be one of [`min`, `avg`, `all`]: `return_type` can be one of [`min`, `avg`, `all`]:
@@ -162,11 +164,19 @@ class WorldModel(nn.Module):
z = self.task_emb(z, task) z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
out = (self._target_Qs if target else self._Qs)(z) if target:
qnet = self._target_Qs
elif detach:
qnet = self._detach_Qs
else:
qnet = self._Qs
out = qnet(z)
if return_type == 'all': if return_type == 'all':
return out return out
Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)] qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg) Q = math.two_hot_inv(out[qidx], self.cfg)
return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2 if return_type == "min":
return Q.min(0).values
return Q.sum(0) / 2

View File

@@ -65,7 +65,7 @@ simnorm_dim: 8
wandb_project: ??? wandb_project: ???
wandb_entity: ??? wandb_entity: ???
wandb_silent: false wandb_silent: false
disable_wandb: true enable_wandb: true
save_csv: true save_csv: true
# misc # misc
@@ -86,3 +86,6 @@ action_dims: ???
episode_lengths: ??? episode_lengths: ???
seed_steps: ??? seed_steps: ???
bin_size: ??? bin_size: ???
# speedups
compile: False

View File

@@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper):
self.cfg = cfg self.cfg = cfg
self.camera_id = 'hand_side_inter' self.camera_id = 'hand_side_inter'
def reset(self):
return self.env.reset()[0]
def step(self, action): def step(self, action):
obs, reward, _, info = self.env.step(action.copy()) obs, reward, _, _, info = self.env.step(action.copy())
obs = obs.astype(np.float32)
info['success'] = info['solved'] info['success'] = info['solved']
return obs, reward, False, info return obs, reward, False, info
@@ -48,7 +50,8 @@ def make_env(cfg):
raise ValueError('Unknown task:', cfg.task) raise ValueError('Unknown task:', cfg.task)
assert cfg.obs == 'state', 'This task only supports state observations.' assert cfg.obs == 'state', 'This task only supports state observations.'
import myosuite import myosuite
env = gym.make(MYOSUITE_TASKS[cfg.task]) from myosuite.utils import gym as gym_utils
env = gym_utils.make(MYOSUITE_TASKS[cfg.task])
env = MyoSuiteWrapper(env, cfg) env = MyoSuiteWrapper(env, cfg)
env = TimeLimit(env, max_episode_steps=100) env = TimeLimit(env, max_episode_steps=100)
env.max_episode_steps = env._max_episode_steps env.max_episode_steps = env._max_episode_steps

View File

@@ -1,13 +1,13 @@
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from common import math from common import math
from common.scale import RunningScale from common.scale import RunningScale
from common.world_model import WorldModel from common.world_model import WorldModel
from tensordict import TensorDict
class TDMPC2: class TDMPC2(torch.nn.Module):
""" """
TD-MPC2 agent. Implements training + inference. TD-MPC2 agent. Implements training + inference.
Can be used for both single-task and multi-task experiments, Can be used for both single-task and multi-task experiments,
@@ -15,23 +15,41 @@ class TDMPC2:
""" """
def __init__(self, cfg): def __init__(self, cfg):
super().__init__()
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda') self.device = torch.device('cuda:0')
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg).to(self.device)
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()}, {'params': self.model._reward.parameters()},
{'params': self.model._Qs.parameters()}, {'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []} {'params': self.model._task_emb.parameters() if self.cfg.multitask else []
], lr=self.cfg.lr) }
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5) ], lr=self.cfg.lr, capturable=True)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True)
self.model.eval() self.model.eval()
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
print('Compiling update function with torch.compile...')
self._update = torch.compile(self._update, mode="reduce-overhead")
@property
def plan(self):
_plan_val = getattr(self, "_plan_val", None)
if _plan_val is not None:
return _plan_val
if self.cfg.compile:
plan = torch.compile(self._plan, mode="reduce-overhead")
else:
plan = self._plan
self._plan_val = plan
return self._plan_val
def _get_discount(self, episode_length): def _get_discount(self, episode_length):
""" """
@@ -84,10 +102,10 @@ class TDMPC2:
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
z = self.model.encode(obs, task)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
else: else:
z = self.model.encode(obs, task)
a = self.model.pi(z, task)[int(not eval_mode)][0] a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu() return a.cpu()
@@ -98,12 +116,13 @@ class TDMPC2:
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[t], task)
G += discount * reward G = G + discount * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
""" """
Plan a sequence of actions using the learned world model. Plan a sequence of actions using the learned world model.
@@ -117,6 +136,7 @@ class TDMPC2:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
# Sample policy trajectories # Sample policy trajectories
z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.repeat(self.cfg.num_pi_trajs, 1)
@@ -128,7 +148,7 @@ class TDMPC2:
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0: if not t0:
mean[:-1] = self._prev_mean[1:] mean[:-1] = self._prev_mean[1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
@@ -139,36 +159,37 @@ class TDMPC2:
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
.clamp(-1, 1) actions_sample = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters # Update parameters
max_value = elite_value.max(0)[0] max_value = elite_value.max(0).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score /= score.sum(0) score = score / score.sum(0)
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
.clamp_(self.cfg.min_std, self.cfg.max_std) std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action
score = score.squeeze(1).cpu().numpy() rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
self._prev_mean = mean
a, std = actions[0], std[0] a, std = actions[0], std[0]
if not eval_mode: if not eval_mode:
a += std * torch.randn(self.cfg.action_dim, device=std.device) a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
return a.clamp_(-1, 1) self._prev_mean.copy_(mean)
return a.clamp(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """
@@ -181,10 +202,8 @@ class TDMPC2:
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
_, pis, log_pis, _ = self.model.pi(zs, task) _, pis, log_pis, _ = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg') qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
self.scale.update(qs[0]) self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
@@ -192,11 +211,11 @@ class TDMPC2:
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
self.model.track_q_grad(True) self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.item() return pi_loss.detach(), pi_grad_norm
@torch.no_grad() @torch.no_grad()
def _td_target(self, next_z, reward, task): def _td_target(self, next_z, reward, task):
@@ -215,6 +234,71 @@ class TDMPC2:
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
def _update(self, obs, action, reward, task=None):
# Compute targets
with torch.no_grad():
next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task)
# Prepare for update
self.model.train()
# Latent rollout
zs = torch.empty(self.cfg.horizon+1, self.cfg.batch_size, self.cfg.latent_dim, device=self.device)
z = self.model.encode(obs[0], task)
zs[0] = z
consistency_loss = 0
for t, (_action, _next_z) in enumerate(zip(action.unbind(0), next_z.unbind(0))):
z = self.model.next(z, _action, task)
consistency_loss = consistency_loss + F.mse_loss(z, _next_z) * self.cfg.rho**t
zs[t+1] = z
# Predictions
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
# Compute losses
reward_loss, value_loss = 0, 0
for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.value_coef * value_loss
)
# Update model
total_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
self.optim.step()
self.optim.zero_grad(set_to_none=True)
# Update policy
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
# Update target Q-functions
self.model.soft_update_target_Q()
# Return training statistics
self.model.eval()
return TensorDict({
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"value_loss": value_loss,
"pi_loss": pi_loss,
"total_loss": total_loss,
"grad_norm": grad_norm,
"pi_grad_norm": pi_grad_norm,
"pi_scale": self.scale.value,
}).detach().mean()
def update(self, buffer): def update(self, buffer):
""" """
Main update function. Corresponds to one iteration of model learning. Main update function. Corresponds to one iteration of model learning.
@@ -226,65 +310,8 @@ class TDMPC2:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
obs, action, reward, task = buffer.sample() obs, action, reward, task = buffer.sample()
kwargs = {}
# Compute targets if task is not None:
with torch.no_grad(): kwargs["task"] = task
next_z = self.model.encode(obs[1:], task) torch.compiler.cudagraph_mark_step_begin()
td_targets = self._td_target(next_z, reward, task) return self._update(obs, action, reward, **kwargs)
# Prepare for update
self.optim.zero_grad(set_to_none=True)
self.model.train()
# Latent rollout
zs = torch.empty(self.cfg.horizon+1, self.cfg.batch_size, self.cfg.latent_dim, device=self.device)
z = self.model.encode(obs[0], task)
zs[0] = z
consistency_loss = 0
for t in range(self.cfg.horizon):
z = self.model.next(z, action[t], task)
consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t
zs[t+1] = z
# Predictions
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
# Compute losses
reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.value_coef * value_loss
)
# Update model
total_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
self.optim.step()
# Update policy
pi_loss = self.update_pi(zs.detach(), task)
# Update target Q-functions
self.model.soft_update_target_Q()
# Return training statistics
self.model.eval()
return {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"value_loss": float(value_loss.mean().item()),
"pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()),
"grad_norm": float(grad_norm),
"pi_scale": float(self.scale.value),
}

View File

@@ -1,6 +1,8 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = 'egl'
os.environ['LAZY_LEGACY_OP'] = '0' os.environ['LAZY_LEGACY_OP'] = '0'
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
os.environ['TORCH_LOGS'] = "+recompiles"
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import torch import torch
@@ -18,6 +20,7 @@ from trainer.online_trainer import OnlineTrainer
from common.logger import Logger from common.logger import Logger
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
@hydra.main(config_name='config', config_path='.') @hydra.main(config_name='config', config_path='.')

View File

@@ -8,7 +8,6 @@ class Trainer:
self.buffer = buffer self.buffer = buffer
self.logger = logger self.logger = logger
print('Architecture:', self.agent.model) print('Architecture:', self.agent.model)
print("Learnable parameters: {:,}".format(self.agent.model.total_params))
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""

View File

@@ -27,6 +27,7 @@ class OfflineTrainer(Trainer):
for _ in range(self.cfg.eval_episodes): for _ in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0 obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
while not done: while not done:
torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx) action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
ep_reward += reward ep_reward += reward
@@ -44,13 +45,12 @@ class OfflineTrainer(Trainer):
'Offline training only supports multitask training with mt30 or mt80 task sets.' 'Offline training only supports multitask training with mt30 or mt80 task sets.'
# Load data # Load data
assert self.cfg.task in self.cfg.data_dir, \
f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
f'please double-check your config.'
fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp))) fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}' assert len(fps) > 0, f'No data found at {fp}'
print(f'Found {len(fps)} files in {fp}') print(f'Found {len(fps)} files in {fp}')
assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \
f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.'
# Create buffer for sampling # Create buffer for sampling
_cfg = deepcopy(self.cfg) _cfg = deepcopy(self.cfg)
@@ -65,8 +65,9 @@ class OfflineTrainer(Trainer):
f'please double-check your config.' f'please double-check your config.'
for i in range(len(td)): for i in range(len(td)):
self.buffer.add(td[i]) self.buffer.add(td[i])
assert self.buffer.num_eps == self.buffer.capacity, \ expected_episodes = _cfg.buffer_size // _cfg.episode_length
f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' assert self.buffer.num_eps == expected_episodes, \
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
print(f'Training agent for {self.cfg.steps} iterations...') print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {} metrics = {}

View File

@@ -3,7 +3,6 @@ from time import time
import numpy as np import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from trainer.base import Trainer from trainer.base import Trainer
@@ -32,6 +31,7 @@ class OnlineTrainer(Trainer):
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0)) self.logger.video.init(self.env, enabled=(i==0))
while not done: while not done:
torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True) action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
ep_reward += reward ep_reward += reward
@@ -57,18 +57,17 @@ class OnlineTrainer(Trainer):
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
td = TensorDict(dict( td = TensorDict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
), batch_size=(1,)) batch_size=(1,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, True train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
if self._step % self.cfg.eval_freq == 0: if self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True