Merge pull request #46 from vmoens/cudagraphs
[WIP,Feature] Add cudagraphs and compile option
This commit is contained in:
@@ -1,56 +1,46 @@
|
|||||||
name: tdmpc2
|
name: graph
|
||||||
channels:
|
channels:
|
||||||
- pytorch-nightly
|
- pytorch-nightly
|
||||||
- nvidia
|
- nvidia
|
||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- cudatoolkit=11.7
|
- glew=2.2.0
|
||||||
- glew=2.1.0
|
- glib=2.78.4
|
||||||
- glib=2.68.4
|
- pip=24.0
|
||||||
- pip=21.0
|
- python=3.9
|
||||||
- python=3.9.0
|
- pytorch
|
||||||
- pytorch>=2.2.2
|
- pytorch-cuda=12.4
|
||||||
- torchvision>=0.16.2
|
- torchvision
|
||||||
- pip:
|
- pip:
|
||||||
- absl-py==2.0.0
|
- absl-py==2.1.0
|
||||||
- "cython<3"
|
|
||||||
- dm-control==1.0.8
|
- dm-control==1.0.8
|
||||||
|
- glfw==2.7.0
|
||||||
- ffmpeg==1.4
|
- ffmpeg==1.4
|
||||||
- glfw==2.6.4
|
- imageio==2.34.1
|
||||||
|
- imageio-ffmpeg==0.4.9
|
||||||
|
- h5py==3.11.0
|
||||||
- hydra-core==1.3.2
|
- hydra-core==1.3.2
|
||||||
- hydra-submitit-launcher==1.2.0
|
- hydra-submitit-launcher==1.2.0
|
||||||
- imageio==2.33.1
|
- submitit==1.5.1
|
||||||
- imageio-ffmpeg==0.4.9
|
- omegaconf==2.3.0
|
||||||
- kornia==0.7.1
|
|
||||||
- moviepy==1.0.3
|
- moviepy==1.0.3
|
||||||
- mujoco==2.3.1
|
- mujoco==2.3.1
|
||||||
- mujoco-py==2.1.2.14
|
- numpy==1.24.4
|
||||||
- numpy==1.23.5
|
- tensordict-nightly
|
||||||
- omegaconf==2.3.0
|
- torchrl-nightly
|
||||||
- open3d==0.18.0
|
- kornia==0.7.2
|
||||||
- opencv-contrib-python==4.9.0.80
|
|
||||||
- opencv-python==4.9.0.80
|
|
||||||
- pandas==2.1.4
|
|
||||||
- sapien==2.2.1
|
|
||||||
- submitit==1.5.1
|
|
||||||
- setuptools==65.5.0
|
|
||||||
- patchelf==0.17.2.1
|
|
||||||
- protobuf==4.25.2
|
|
||||||
- pillow==10.2.0
|
|
||||||
- pyquaternion==0.9.9
|
|
||||||
- tensordict-nightly==2024.3.26
|
|
||||||
- termcolor==2.4.0
|
- termcolor==2.4.0
|
||||||
- torchrl-nightly==2024.3.26
|
- tqdm==4.66.4
|
||||||
- transforms3d==0.4.1
|
- pandas==2.0.3
|
||||||
- trimesh==4.0.9
|
- wandb==0.17.4
|
||||||
- tqdm==4.66.1
|
- matplotlib==3.7.5
|
||||||
- wandb==0.16.2
|
- seaborn==0.13.2
|
||||||
- wheel==0.38.0
|
- gpustat==1.1.1
|
||||||
####################
|
####################
|
||||||
# Gym:
|
# Gym:
|
||||||
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
# (unmaintained but required for maniskill2/meta-world/myosuite)
|
||||||
# - gym==0.21.0
|
- gym==0.21.0
|
||||||
####################
|
####################
|
||||||
# ManiSkill2:
|
# ManiSkill2:
|
||||||
# (requires gym==0.21.0 which occasionally breaks)
|
# (requires gym==0.21.0 which occasionally breaks)
|
||||||
|
|||||||
34
requirements.txt
Normal file
34
requirements.txt
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
absl-py
|
||||||
|
cython
|
||||||
|
dm-control
|
||||||
|
ffmpeg
|
||||||
|
glfw
|
||||||
|
hydra-core
|
||||||
|
hydra-submitit-launcher
|
||||||
|
imageio
|
||||||
|
imageio-ffmpeg
|
||||||
|
kornia
|
||||||
|
moviepy
|
||||||
|
mujoco
|
||||||
|
mujoco-py
|
||||||
|
numpy<2
|
||||||
|
omegaconf
|
||||||
|
open3d
|
||||||
|
opencv-contrib-python
|
||||||
|
opencv-python
|
||||||
|
pandas
|
||||||
|
sapien
|
||||||
|
submitit
|
||||||
|
setuptools
|
||||||
|
patchelf
|
||||||
|
protobuf
|
||||||
|
pillow
|
||||||
|
pyquaternion
|
||||||
|
tensordict-nightly
|
||||||
|
termcolor
|
||||||
|
torchrl-nightly
|
||||||
|
transforms3d
|
||||||
|
trimesh
|
||||||
|
tqdm
|
||||||
|
wandb
|
||||||
|
wheel
|
||||||
@@ -12,7 +12,7 @@ class Buffer():
|
|||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._device = torch.device('cuda')
|
self._device = torch.device('cuda:0')
|
||||||
self._capacity = min(cfg.buffer_size, cfg.steps)
|
self._capacity = min(cfg.buffer_size, cfg.steps)
|
||||||
self._sampler = SliceSampler(
|
self._sampler = SliceSampler(
|
||||||
num_slices=self.cfg.batch_size,
|
num_slices=self.cfg.batch_size,
|
||||||
@@ -28,7 +28,7 @@ class Buffer():
|
|||||||
def capacity(self):
|
def capacity(self):
|
||||||
"""Return the capacity of the buffer."""
|
"""Return the capacity of the buffer."""
|
||||||
return self._capacity
|
return self._capacity
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_eps(self):
|
def num_eps(self):
|
||||||
"""Return the number of episodes in the buffer."""
|
"""Return the number of episodes in the buffer."""
|
||||||
@@ -41,8 +41,8 @@ class Buffer():
|
|||||||
return ReplayBuffer(
|
return ReplayBuffer(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
sampler=self._sampler,
|
sampler=self._sampler,
|
||||||
pin_memory=True,
|
pin_memory=False,
|
||||||
prefetch=1,
|
prefetch=0,
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -58,32 +58,30 @@ class Buffer():
|
|||||||
total_bytes = bytes_per_step*self._capacity
|
total_bytes = bytes_per_step*self._capacity
|
||||||
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
print(f'Storage required: {total_bytes/1e9:.2f} GB')
|
||||||
# Heuristic: decide whether to use CUDA or CPU memory
|
# Heuristic: decide whether to use CUDA or CPU memory
|
||||||
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu'
|
storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'
|
||||||
print(f'Using {storage_device.upper()} memory for storage.')
|
print(f'Using {storage_device.upper()} memory for storage.')
|
||||||
|
self._storage_device = torch.device(storage_device)
|
||||||
return self._reserve_buffer(
|
return self._reserve_buffer(
|
||||||
LazyTensorStorage(self._capacity, device=torch.device(storage_device))
|
LazyTensorStorage(self._capacity, device=self._storage_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _to_device(self, *args, device=None):
|
|
||||||
if device is None:
|
|
||||||
device = self._device
|
|
||||||
return (arg.to(device, non_blocking=True) \
|
|
||||||
if arg is not None else None for arg in args)
|
|
||||||
|
|
||||||
def _prepare_batch(self, td):
|
def _prepare_batch(self, td):
|
||||||
"""
|
"""
|
||||||
Prepare a sampled batch for training (post-processing).
|
Prepare a sampled batch for training (post-processing).
|
||||||
Expects `td` to be a TensorDict with batch size TxB.
|
Expects `td` to be a TensorDict with batch size TxB.
|
||||||
"""
|
"""
|
||||||
obs = td['obs']
|
td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True)
|
||||||
action = td['action'][1:]
|
obs = td.get('obs').contiguous()
|
||||||
reward = td['reward'][1:].unsqueeze(-1)
|
action = td.get('action')[1:].contiguous()
|
||||||
task = td['task'][0] if 'task' in td.keys() else None
|
reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
|
||||||
return self._to_device(obs, action, reward, task)
|
task = td.get('task', None)
|
||||||
|
if task is not None:
|
||||||
|
task = task[0].contiguous()
|
||||||
|
return obs, action, reward, task
|
||||||
|
|
||||||
def add(self, td):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer."""
|
"""Add an episode to the buffer."""
|
||||||
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps
|
td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
|
||||||
if self._num_eps == 0:
|
if self._num_eps == 0:
|
||||||
self._buffer = self._init(td)
|
self._buffer = self._init(td)
|
||||||
self._buffer.extend(td)
|
self._buffer.extend(td)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from functorch import combine_state_for_ensemble
|
from tensordict import from_modules
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
class Ensemble(nn.Module):
|
class Ensemble(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -11,14 +11,18 @@ class Ensemble(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, modules, **kwargs):
|
def __init__(self, modules, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
modules = nn.ModuleList(modules)
|
# combine_state_for_ensemble causes graph breaks
|
||||||
fn, params, _ = combine_state_for_ensemble(modules)
|
self.params = from_modules(*modules, as_module=True)
|
||||||
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs)
|
with self.params[0].data.to("meta").to_module(modules[0]):
|
||||||
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
self.module = deepcopy(modules[0])
|
||||||
self._repr = str(modules)
|
self._repr = str(modules)
|
||||||
|
|
||||||
|
def _call(self, params, *args, **kwargs):
|
||||||
|
with params.to_module(self.module):
|
||||||
|
return self.module(*args, **kwargs)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return 'Vectorized ' + self._repr
|
return 'Vectorized ' + self._repr
|
||||||
@@ -32,13 +36,13 @@ class ShiftAug(nn.Module):
|
|||||||
def __init__(self, pad=3):
|
def __init__(self, pad=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pad = pad
|
self.pad = pad
|
||||||
|
self.padding = tuple([self.pad] * 4)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = x.float()
|
x = x.float()
|
||||||
n, _, h, w = x.size()
|
n, _, h, w = x.size()
|
||||||
assert h == w
|
assert h == w
|
||||||
padding = tuple([self.pad] * 4)
|
x = F.pad(x, self.padding, 'replicate')
|
||||||
x = F.pad(x, padding, 'replicate')
|
|
||||||
eps = 1.0 / (h + 2 * self.pad)
|
eps = 1.0 / (h + 2 * self.pad)
|
||||||
arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
|
arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
|
||||||
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
||||||
@@ -59,7 +63,7 @@ class PixelPreprocess(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x.div_(255.).sub_(0.5)
|
return x.div(255.).sub(0.5)
|
||||||
|
|
||||||
|
|
||||||
class SimNorm(nn.Module):
|
class SimNorm(nn.Module):
|
||||||
@@ -67,17 +71,17 @@ class SimNorm(nn.Module):
|
|||||||
Simplicial normalization.
|
Simplicial normalization.
|
||||||
Adapted from https://arxiv.org/abs/2204.00616.
|
Adapted from https://arxiv.org/abs/2204.00616.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = cfg.simnorm_dim
|
self.dim = cfg.simnorm_dim
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shp = x.shape
|
shp = x.shape
|
||||||
x = x.view(*shp[:-1], -1, self.dim)
|
x = x.view(*shp[:-1], -1, self.dim)
|
||||||
x = F.softmax(x, dim=-1)
|
x = F.softmax(x, dim=-1)
|
||||||
return x.view(*shp)
|
return x.view(*shp)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"SimNorm(dim={self.dim})"
|
return f"SimNorm(dim={self.dim})"
|
||||||
|
|
||||||
@@ -87,18 +91,20 @@ class NormedLinear(nn.Linear):
|
|||||||
Linear layer with LayerNorm, activation, and optionally dropout.
|
Linear layer with LayerNorm, activation, and optionally dropout.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs):
|
def __init__(self, *args, dropout=0., act=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.ln = nn.LayerNorm(self.out_features)
|
self.ln = nn.LayerNorm(self.out_features)
|
||||||
|
if act is None:
|
||||||
|
act = nn.Mish(inplace=False)
|
||||||
self.act = act
|
self.act = act
|
||||||
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
|
self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = super().forward(x)
|
x = super().forward(x)
|
||||||
if self.dropout:
|
if self.dropout:
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return self.act(self.ln(x))
|
return self.act(self.ln(x))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
|
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
|
||||||
return f"NormedLinear(in_features={self.in_features}, "\
|
return f"NormedLinear(in_features={self.in_features}, "\
|
||||||
@@ -130,9 +136,9 @@ def conv(in_shape, num_channels, act=None):
|
|||||||
assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
|
assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
|
||||||
layers = [
|
layers = [
|
||||||
ShiftAug(), PixelPreprocess(),
|
ShiftAug(), PixelPreprocess(),
|
||||||
nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True),
|
nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=False),
|
||||||
nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True),
|
nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=False),
|
||||||
nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True),
|
nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=False),
|
||||||
nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()]
|
nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()]
|
||||||
if act:
|
if act:
|
||||||
layers.append(act)
|
layers.append(act)
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from omegaconf import OmegaConf
|
from torchrl._utils import timeit
|
||||||
|
|
||||||
from common import TASK_SET
|
from common import TASK_SET
|
||||||
|
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class Logger:
|
|||||||
group=self._group,
|
group=self._group,
|
||||||
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
|
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
|
||||||
dir=self._log_dir,
|
dir=self._log_dir,
|
||||||
config=OmegaConf.to_container(cfg, resolve=True),
|
config=dataclasses.asdict(cfg),
|
||||||
)
|
)
|
||||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
@@ -238,3 +238,5 @@ class Logger:
|
|||||||
self._log_dir / "eval.csv", header=keys, index=None
|
self._log_dir / "eval.csv", header=keys, index=None
|
||||||
)
|
)
|
||||||
self._print(d, category)
|
self._print(d, category)
|
||||||
|
timeit.print()
|
||||||
|
timeit.erase()
|
||||||
|
|||||||
@@ -9,30 +9,30 @@ def soft_ce(pred, target, cfg):
|
|||||||
return -(target * pred).sum(-1, keepdim=True)
|
return -(target * pred).sum(-1, keepdim=True)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def log_std(x, low, dif):
|
def log_std(x, low, dif):
|
||||||
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def _gaussian_residual(eps, log_std):
|
def _gaussian_residual(eps, log_std):
|
||||||
return -0.5 * eps.pow(2) - log_std
|
return -0.5 * eps.pow(2) - log_std
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def _gaussian_logprob(residual):
|
def _gaussian_logprob(residual):
|
||||||
return residual - 0.5 * torch.log(2 * torch.pi)
|
log2pi = 1.8378770351409912
|
||||||
|
return residual - 0.5 * log2pi
|
||||||
|
|
||||||
|
|
||||||
def gaussian_logprob(eps, log_std, size=None):
|
def gaussian_logprob(eps, log_std, size=None):
|
||||||
"""Compute Gaussian log probability."""
|
"""Compute Gaussian log probability."""
|
||||||
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
||||||
if size is None:
|
if size is None:
|
||||||
size = eps.size(-1)
|
size = eps.shape[-1]
|
||||||
return _gaussian_logprob(residual) * size
|
return _gaussian_logprob(residual) * size
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def _squash(pi):
|
def _squash(pi):
|
||||||
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ def squash(mu, pi, log_pi):
|
|||||||
return mu, pi, log_pi
|
return mu, pi, log_pi
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def symlog(x):
|
def symlog(x):
|
||||||
"""
|
"""
|
||||||
Symmetric logarithmic function.
|
Symmetric logarithmic function.
|
||||||
@@ -54,7 +54,7 @@ def symlog(x):
|
|||||||
return torch.sign(x) * torch.log(1 + torch.abs(x))
|
return torch.sign(x) * torch.log(1 + torch.abs(x))
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def symexp(x):
|
def symexp(x):
|
||||||
"""
|
"""
|
||||||
Symmetric exponential function.
|
Symmetric exponential function.
|
||||||
@@ -70,26 +70,32 @@ def two_hot(x, cfg):
|
|||||||
elif cfg.num_bins == 1:
|
elif cfg.num_bins == 1:
|
||||||
return symlog(x)
|
return symlog(x)
|
||||||
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
|
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
|
||||||
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long()
|
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size)
|
||||||
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1)
|
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx).unsqueeze(-1)
|
||||||
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device)
|
soft_two_hot = torch.zeros(x.shape[0], cfg.num_bins, device=x.device, dtype=x.dtype)
|
||||||
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset)
|
bin_idx = bin_idx.long()
|
||||||
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
|
soft_two_hot = soft_two_hot.scatter(1, bin_idx.unsqueeze(1), 1 - bin_offset)
|
||||||
|
soft_two_hot = soft_two_hot.scatter(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
|
||||||
return soft_two_hot
|
return soft_two_hot
|
||||||
|
|
||||||
|
|
||||||
DREG_BINS = None
|
|
||||||
|
|
||||||
|
|
||||||
def two_hot_inv(x, cfg):
|
def two_hot_inv(x, cfg):
|
||||||
"""Converts a batch of soft two-hot encoded vectors to scalars."""
|
"""Converts a batch of soft two-hot encoded vectors to scalars."""
|
||||||
global DREG_BINS
|
|
||||||
if cfg.num_bins == 0:
|
if cfg.num_bins == 0:
|
||||||
return x
|
return x
|
||||||
elif cfg.num_bins == 1:
|
elif cfg.num_bins == 1:
|
||||||
return symexp(x)
|
return symexp(x)
|
||||||
if DREG_BINS is None:
|
dreg_bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device, dtype=x.dtype)
|
||||||
DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
|
|
||||||
x = F.softmax(x, dim=-1)
|
x = F.softmax(x, dim=-1)
|
||||||
x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True)
|
x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
|
||||||
return symexp(x)
|
return symexp(x)
|
||||||
|
|
||||||
|
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
||||||
|
logits = p.log()
|
||||||
|
# Generate Gumbel noise
|
||||||
|
gumbels = (
|
||||||
|
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
||||||
|
) # ~Gumbel(0,1)
|
||||||
|
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
|
||||||
|
y_soft = gumbels.softmax(dim)
|
||||||
|
return y_soft.argmax(-1)
|
||||||
|
|||||||
@@ -1,48 +1,49 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from torch.nn import Buffer
|
||||||
|
|
||||||
|
class RunningScale(torch.nn.Module):
|
||||||
class RunningScale:
|
|
||||||
"""Running trimmed scale estimator."""
|
"""Running trimmed scale estimator."""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))
|
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
|
||||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))
|
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return dict(value=self._value, percentiles=self._percentiles)
|
return dict(value=self.value, percentiles=self._percentiles)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
def load_state_dict(self, state_dict):
|
||||||
self._value.data.copy_(state_dict['value'])
|
self.value.copy_(state_dict['value'])
|
||||||
self._percentiles.data.copy_(state_dict['percentiles'])
|
self._percentiles.copy_(state_dict['percentiles'])
|
||||||
|
|
||||||
@property
|
def _positions(self, x_shape):
|
||||||
def value(self):
|
positions = self._percentiles * (x_shape-1) / 100
|
||||||
return self._value.cpu().item()
|
floored = torch.floor(positions)
|
||||||
|
ceiled = floored + 1
|
||||||
|
ceiled = torch.where(ceiled > x_shape - 1, x_shape - 1, ceiled)
|
||||||
|
weight_ceiled = positions-floored
|
||||||
|
weight_floored = 1.0 - weight_ceiled
|
||||||
|
return floored.long(), ceiled.long(), weight_floored.unsqueeze(1), weight_ceiled.unsqueeze(1)
|
||||||
|
|
||||||
def _percentile(self, x):
|
def _percentile(self, x):
|
||||||
x_dtype, x_shape = x.dtype, x.shape
|
x_dtype, x_shape = x.dtype, x.shape
|
||||||
x = x.view(x.shape[0], -1)
|
x = x.flatten(1, x.ndim-1)
|
||||||
in_sorted, _ = torch.sort(x, dim=0)
|
in_sorted = torch.sort(x, dim=0).values
|
||||||
positions = self._percentiles * (x.shape[0]-1) / 100
|
floored, ceiled, weight_floored, weight_ceiled = self._positions(x.shape[0])
|
||||||
floored = torch.floor(positions)
|
d0 = in_sorted[floored] * weight_floored
|
||||||
ceiled = floored + 1
|
d1 = in_sorted[ceiled] * weight_ceiled
|
||||||
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
|
return (d0+d1).reshape(-1, *x_shape[1:]).to(x_dtype)
|
||||||
weight_ceiled = positions-floored
|
|
||||||
weight_floored = 1.0 - weight_ceiled
|
|
||||||
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
|
|
||||||
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
|
|
||||||
return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype)
|
|
||||||
|
|
||||||
def update(self, x):
|
def update(self, x):
|
||||||
percentiles = self._percentile(x.detach())
|
percentiles = self._percentile(x.detach())
|
||||||
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
|
value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
|
||||||
self._value.data.lerp_(value, self.cfg.tau)
|
self.value.data.lerp_(value, self.cfg.tau)
|
||||||
|
|
||||||
def __call__(self, x, update=False):
|
def forward(self, x, update=False):
|
||||||
if update:
|
if update:
|
||||||
self.update(x)
|
self.update(x)
|
||||||
return x * (1/self.value)
|
return x / self.value
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f'RunningScale(S: {self.value})'
|
return f'RunningScale(S: {self.value})'
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
|
from tensordict import TensorDict
|
||||||
|
from tensordict.nn import TensorDictParams
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -18,7 +19,7 @@ class WorldModel(nn.Module):
|
|||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
if cfg.multitask:
|
if cfg.multitask:
|
||||||
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
|
||||||
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim)
|
self.register_buffer("_action_masks", torch.zeros(len(cfg.tasks), cfg.action_dim))
|
||||||
for i in range(len(cfg.tasks)):
|
for i in range(len(cfg.tasks)):
|
||||||
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
self._action_masks[i, :cfg.action_dims[i]] = 1.
|
||||||
self._encoder = layers.enc(cfg)
|
self._encoder = layers.enc(cfg)
|
||||||
@@ -27,26 +28,35 @@ class WorldModel(nn.Module):
|
|||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
init.zero_([self._reward[-1].weight, self._Qs.params[-2]])
|
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
|
||||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
|
||||||
self.log_std_min = torch.tensor(cfg.log_std_min)
|
self.register_buffer("log_std_min", torch.tensor(cfg.log_std_min))
|
||||||
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min
|
self.register_buffer("log_std_dif", torch.tensor(cfg.log_std_max) - self.log_std_min)
|
||||||
|
self.init()
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
# Create params
|
||||||
|
self._detach_Qs_params = TensorDictParams(self._Qs.params.data, no_convert=True)
|
||||||
|
self._target_Qs_params = TensorDictParams(self._Qs.params.data.clone(), no_convert=True)
|
||||||
|
|
||||||
|
# Create modules
|
||||||
|
with self._detach_Qs_params.data.to("meta").to_module(self._Qs.module):
|
||||||
|
self._detach_Qs = deepcopy(self._Qs)
|
||||||
|
self._target_Qs = deepcopy(self._Qs)
|
||||||
|
|
||||||
|
# Assign params to modules
|
||||||
|
self._detach_Qs.params = self._detach_Qs_params
|
||||||
|
self._target_Qs.params = self._target_Qs_params
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_params(self):
|
def total_params(self):
|
||||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
"""
|
|
||||||
Overriding `to` method to also move additional tensors to device.
|
|
||||||
"""
|
|
||||||
super().to(*args, **kwargs)
|
super().to(*args, **kwargs)
|
||||||
if self.cfg.multitask:
|
self.init()
|
||||||
self._action_masks = self._action_masks.to(*args, **kwargs)
|
|
||||||
self.log_std_min = self.log_std_min.to(*args, **kwargs)
|
|
||||||
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
"""
|
"""
|
||||||
Overriding `train` method to keep target Q-networks in eval mode.
|
Overriding `train` method to keep target Q-networks in eval mode.
|
||||||
@@ -55,26 +65,12 @@ class WorldModel(nn.Module):
|
|||||||
self._target_Qs.train(False)
|
self._target_Qs.train(False)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def track_q_grad(self, mode=True):
|
|
||||||
"""
|
|
||||||
Enables/disables gradient tracking of Q-networks.
|
|
||||||
Avoids unnecessary computation during policy optimization.
|
|
||||||
This method also enables/disables gradients for task embeddings.
|
|
||||||
"""
|
|
||||||
for p in self._Qs.parameters():
|
|
||||||
p.requires_grad_(mode)
|
|
||||||
if self.cfg.multitask:
|
|
||||||
for p in self._task_emb.parameters():
|
|
||||||
p.requires_grad_(mode)
|
|
||||||
|
|
||||||
def soft_update_target_Q(self):
|
def soft_update_target_Q(self):
|
||||||
"""
|
"""
|
||||||
Soft-update target Q-networks using Polyak averaging.
|
Soft-update target Q-networks using Polyak averaging.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
self._target_Qs_params.lerp_(self._detach_Qs_params, self.cfg.tau)
|
||||||
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
|
|
||||||
p_target.data.lerp_(p.data, self.cfg.tau)
|
|
||||||
|
|
||||||
def task_emb(self, x, task):
|
def task_emb(self, x, task):
|
||||||
"""
|
"""
|
||||||
Continuous task embedding for multi-task experiments.
|
Continuous task embedding for multi-task experiments.
|
||||||
@@ -109,7 +105,7 @@ class WorldModel(nn.Module):
|
|||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._dynamics(z)
|
return self._dynamics(z)
|
||||||
|
|
||||||
def reward(self, z, a, task):
|
def reward(self, z, a, task):
|
||||||
"""
|
"""
|
||||||
Predicts instantaneous (single-step) reward.
|
Predicts instantaneous (single-step) reward.
|
||||||
@@ -147,7 +143,7 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
return mu, pi, log_pi, log_std
|
return mu, pi, log_pi, log_std
|
||||||
|
|
||||||
def Q(self, z, a, task, return_type='min', target=False):
|
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
||||||
"""
|
"""
|
||||||
Predict state-action value.
|
Predict state-action value.
|
||||||
`return_type` can be one of [`min`, `avg`, `all`]:
|
`return_type` can be one of [`min`, `avg`, `all`]:
|
||||||
@@ -160,13 +156,21 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
|
|
||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
out = (self._target_Qs if target else self._Qs)(z)
|
if target:
|
||||||
|
qnet = self._target_Qs
|
||||||
|
elif detach:
|
||||||
|
qnet = self._detach_Qs
|
||||||
|
else:
|
||||||
|
qnet = self._Qs
|
||||||
|
out = qnet(z)
|
||||||
|
|
||||||
if return_type == 'all':
|
if return_type == 'all':
|
||||||
return out
|
return out
|
||||||
|
|
||||||
Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)]
|
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
||||||
Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg)
|
Q = math.two_hot_inv(out[qidx], self.cfg)
|
||||||
return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2
|
if return_type == "min":
|
||||||
|
return Q.min(0).values
|
||||||
|
return Q.sum(0) / 2
|
||||||
|
|||||||
@@ -86,3 +86,7 @@ action_dims: ???
|
|||||||
episode_lengths: ???
|
episode_lengths: ???
|
||||||
seed_steps: ???
|
seed_steps: ???
|
||||||
bin_size: ???
|
bin_size: ???
|
||||||
|
|
||||||
|
# compile
|
||||||
|
compile: False
|
||||||
|
cudagraphs: False
|
||||||
|
|||||||
202
tdmpc2/tdmpc2.py
202
tdmpc2/tdmpc2.py
@@ -1,13 +1,18 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import functools
|
||||||
|
from torchrl._utils import timeit
|
||||||
|
|
||||||
from common import math
|
from common import math
|
||||||
from common.scale import RunningScale
|
from common.scale import RunningScale
|
||||||
from common.world_model import WorldModel
|
from common.world_model import WorldModel
|
||||||
|
from tensordict.nn import CudaGraphModule
|
||||||
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
CG_WARMUP = 1000
|
||||||
|
|
||||||
|
|
||||||
class TDMPC2:
|
class TDMPC2(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
TD-MPC2 agent. Implements training + inference.
|
TD-MPC2 agent. Implements training + inference.
|
||||||
Can be used for both single-task and multi-task experiments,
|
Can be used for both single-task and multi-task experiments,
|
||||||
@@ -15,23 +20,71 @@ class TDMPC2:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.device = torch.device('cuda')
|
|
||||||
|
self.device = torch.device('cuda:0')
|
||||||
|
|
||||||
self.model = WorldModel(cfg).to(self.device)
|
self.model = WorldModel(cfg).to(self.device)
|
||||||
|
capturable = True
|
||||||
self.optim = torch.optim.Adam([
|
self.optim = torch.optim.Adam([
|
||||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||||
{'params': self.model._dynamics.parameters()},
|
{'params': self.model._dynamics.parameters()},
|
||||||
{'params': self.model._reward.parameters()},
|
{'params': self.model._reward.parameters()},
|
||||||
{'params': self.model._Qs.parameters()},
|
{'params': self.model._Qs.parameters()},
|
||||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
|
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||||
], lr=self.cfg.lr)
|
}
|
||||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5)
|
], lr=self.cfg.lr, capturable=capturable)
|
||||||
|
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.scale = RunningScale(cfg)
|
self.scale = RunningScale(cfg)
|
||||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
|
|
||||||
|
if cfg.compile:
|
||||||
|
mode = None if cfg.cudagraphs else "reduce-overhead"
|
||||||
|
print('compiling - update')
|
||||||
|
self._update = torch.compile(self._update, mode=mode)
|
||||||
|
|
||||||
|
if cfg.cudagraphs:
|
||||||
|
print('cudagraphs - update')
|
||||||
|
self._update = CudaGraphModule(self._update, warmup=CG_WARMUP)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def plan(self):
|
||||||
|
_plan_val = getattr(self, "_plan_val", None)
|
||||||
|
if _plan_val is not None:
|
||||||
|
return _plan_val
|
||||||
|
if self.cfg.cudagraphs:
|
||||||
|
print('cudagraphs - plan')
|
||||||
|
self._plan_dict = {
|
||||||
|
(True, True): functools.partial(self._plan, t0=True, eval_mode=True),
|
||||||
|
(False, True): functools.partial(self._plan, t0=False, eval_mode=True),
|
||||||
|
(True, False): functools.partial(self._plan, t0=True, eval_mode=False),
|
||||||
|
(False, False): functools.partial(self._plan, t0=False, eval_mode=False),
|
||||||
|
}
|
||||||
|
if self.cfg.compile:
|
||||||
|
print('compiling - plan')
|
||||||
|
mode = None
|
||||||
|
self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()}
|
||||||
|
|
||||||
|
self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()}
|
||||||
|
def plan(obs, t0=False, eval_mode=False, task=None):
|
||||||
|
if task is not None:
|
||||||
|
kwargs = {"task": task}
|
||||||
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
torch.compiler.cudagraph_mark_step_begin()
|
||||||
|
return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs)
|
||||||
|
elif self.cfg.compile:
|
||||||
|
plan = torch.compile(self._plan, mode="reduce-overhead")
|
||||||
|
else:
|
||||||
|
plan = self._plan
|
||||||
|
self._plan_val = plan
|
||||||
|
return self._plan_val
|
||||||
|
|
||||||
def _get_discount(self, episode_length):
|
def _get_discount(self, episode_length):
|
||||||
"""
|
"""
|
||||||
@@ -51,7 +104,7 @@ class TDMPC2:
|
|||||||
def save(self, fp):
|
def save(self, fp):
|
||||||
"""
|
"""
|
||||||
Save state dict of the agent to filepath.
|
Save state dict of the agent to filepath.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fp (str): Filepath to save state dict to.
|
fp (str): Filepath to save state dict to.
|
||||||
"""
|
"""
|
||||||
@@ -60,7 +113,7 @@ class TDMPC2:
|
|||||||
def load(self, fp):
|
def load(self, fp):
|
||||||
"""
|
"""
|
||||||
Load a saved state dict from filepath (or dictionary) into current agent.
|
Load a saved state dict from filepath (or dictionary) into current agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fp (str or dict): Filepath or state dict to load.
|
fp (str or dict): Filepath or state dict to load.
|
||||||
"""
|
"""
|
||||||
@@ -71,23 +124,23 @@ class TDMPC2:
|
|||||||
def act(self, obs, t0=False, eval_mode=False, task=None):
|
def act(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
"""
|
"""
|
||||||
Select an action by planning in the latent space of the world model.
|
Select an action by planning in the latent space of the world model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
obs (torch.Tensor): Observation from the environment.
|
obs (torch.Tensor): Observation from the environment.
|
||||||
t0 (bool): Whether this is the first observation in the episode.
|
t0 (bool): Whether this is the first observation in the episode.
|
||||||
eval_mode (bool): Whether to use the mean of the action distribution.
|
eval_mode (bool): Whether to use the mean of the action distribution.
|
||||||
task (int): Task index (only used for multi-task experiments).
|
task (int): Task index (only used for multi-task experiments).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action to take in the environment.
|
torch.Tensor: Action to take in the environment.
|
||||||
"""
|
"""
|
||||||
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
task = torch.tensor([task], device=self.device)
|
task = torch.tensor([task], device=self.device)
|
||||||
z = self.model.encode(obs, task)
|
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task)
|
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
||||||
else:
|
else:
|
||||||
|
z = self.model.encode(obs, task)
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||||
return a.cpu()
|
return a.cpu()
|
||||||
|
|
||||||
@@ -98,15 +151,16 @@ class TDMPC2:
|
|||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[t], task)
|
||||||
G += discount * reward
|
G = G + discount * reward
|
||||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
|
discount = discount * discount_update
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def plan(self, z, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
"""
|
"""
|
||||||
Plan a sequence of actions using the learned world model.
|
Plan a sequence of actions using the learned world model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
z (torch.Tensor): Latent state from which to plan.
|
z (torch.Tensor): Latent state from which to plan.
|
||||||
t0 (bool): Whether this is the first observation in the episode.
|
t0 (bool): Whether this is the first observation in the episode.
|
||||||
@@ -115,8 +169,9 @@ class TDMPC2:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action to take in the environment.
|
torch.Tensor: Action to take in the environment.
|
||||||
"""
|
"""
|
||||||
# Sample policy trajectories
|
# Sample policy trajectories
|
||||||
|
z = self.model.encode(obs, task)
|
||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||||
@@ -128,52 +183,53 @@ class TDMPC2:
|
|||||||
# Initialize state and parameters
|
# Initialize state and parameters
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.repeat(self.cfg.num_samples, 1)
|
||||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||||
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||||
if not t0:
|
if not t0:
|
||||||
mean[:-1] = self._prev_mean[1:]
|
mean[:-1] = self._prev_mean[1:]
|
||||||
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||||
|
|
||||||
# Iterate MPPI
|
# Iterate MPPI
|
||||||
for _ in range(self.cfg.iterations):
|
for _ in range(self.cfg.iterations):
|
||||||
|
|
||||||
# Sample actions
|
# Sample actions
|
||||||
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \
|
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
||||||
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \
|
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
|
||||||
.clamp(-1, 1)
|
actions_sample = actions_sample.clamp(-1, 1)
|
||||||
|
actions[:, self.cfg.num_pi_trajs:] = actions_sample
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
actions = actions * self.model._action_masks[task]
|
actions = actions * self.model._action_masks[task]
|
||||||
|
|
||||||
# Compute elite actions
|
# Compute elite actions
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||||
|
|
||||||
# Update parameters
|
# Update parameters
|
||||||
max_value = elite_value.max(0)[0]
|
max_value = elite_value.max(0).values
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||||
score /= score.sum(0)
|
score = score / score.sum(0)
|
||||||
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9)
|
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
||||||
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \
|
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
||||||
.clamp_(self.cfg.min_std, self.cfg.max_std)
|
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
mean = mean * self.model._action_masks[task]
|
mean = mean * self.model._action_masks[task]
|
||||||
std = std * self.model._action_masks[task]
|
std = std * self.model._action_masks[task]
|
||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
score = score.squeeze(1).cpu().numpy()
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||||
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)]
|
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||||
self._prev_mean = mean
|
|
||||||
a, std = actions[0], std[0]
|
a, std = actions[0], std[0]
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
a += std * torch.randn(self.cfg.action_dim, device=std.device)
|
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||||
return a.clamp_(-1, 1)
|
self._prev_mean.copy_(mean)
|
||||||
|
return a.clamp(-1, 1)
|
||||||
|
|
||||||
def update_pi(self, zs, task):
|
def update_pi(self, zs, task):
|
||||||
"""
|
"""
|
||||||
Update policy using a sequence of latent states.
|
Update policy using a sequence of latent states.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
zs (torch.Tensor): Sequence of latent states.
|
zs (torch.Tensor): Sequence of latent states.
|
||||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||||
@@ -181,10 +237,8 @@ class TDMPC2:
|
|||||||
Returns:
|
Returns:
|
||||||
float: Loss of the policy update.
|
float: Loss of the policy update.
|
||||||
"""
|
"""
|
||||||
self.pi_optim.zero_grad(set_to_none=True)
|
|
||||||
self.model.track_q_grad(False)
|
|
||||||
_, pis, log_pis, _ = self.model.pi(zs, task)
|
_, pis, log_pis, _ = self.model.pi(zs, task)
|
||||||
qs = self.model.Q(zs, pis, task, return_type='avg')
|
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
|
||||||
self.scale.update(qs[0])
|
self.scale.update(qs[0])
|
||||||
qs = self.scale(qs)
|
qs = self.scale(qs)
|
||||||
|
|
||||||
@@ -192,22 +246,23 @@ class TDMPC2:
|
|||||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||||
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
||||||
pi_loss.backward()
|
pi_loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||||
self.pi_optim.step()
|
self.pi_optim.step()
|
||||||
self.model.track_q_grad(True)
|
# For some reason, cudagraph prefers to see the zero grad after step
|
||||||
|
self.pi_optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
return pi_loss.item()
|
return pi_loss.detach(), pi_grad_norm
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _td_target(self, next_z, reward, task):
|
def _td_target(self, next_z, reward, task):
|
||||||
"""
|
"""
|
||||||
Compute the TD-target from a reward and the observation at the following time step.
|
Compute the TD-target from a reward and the observation at the following time step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
next_z (torch.Tensor): Latent state at the following time step.
|
next_z (torch.Tensor): Latent state at the following time step.
|
||||||
reward (torch.Tensor): Reward at the current time step.
|
reward (torch.Tensor): Reward at the current time step.
|
||||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: TD-target.
|
torch.Tensor: TD-target.
|
||||||
"""
|
"""
|
||||||
@@ -218,22 +273,28 @@ class TDMPC2:
|
|||||||
def update(self, buffer):
|
def update(self, buffer):
|
||||||
"""
|
"""
|
||||||
Main update function. Corresponds to one iteration of model learning.
|
Main update function. Corresponds to one iteration of model learning.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
buffer (common.buffer.Buffer): Replay buffer.
|
buffer (common.buffer.Buffer): Replay buffer.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary of training statistics.
|
dict: Dictionary of training statistics.
|
||||||
"""
|
"""
|
||||||
obs, action, reward, task = buffer.sample()
|
with timeit("sample"):
|
||||||
|
obs, action, reward, task = buffer.sample()
|
||||||
|
kwargs = {}
|
||||||
|
if task is not None:
|
||||||
|
kwargs["task"] = task
|
||||||
|
torch.compiler.cudagraph_mark_step_begin()
|
||||||
|
return self._update(obs, action, reward, **kwargs)
|
||||||
|
|
||||||
|
def _update(self, obs, action, reward, task=None):
|
||||||
# Compute targets
|
# Compute targets
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_z = self.model.encode(obs[1:], task)
|
next_z = self.model.encode(obs[1:], task)
|
||||||
td_targets = self._td_target(next_z, reward, task)
|
td_targets = self._td_target(next_z, reward, task)
|
||||||
|
|
||||||
# Prepare for update
|
# Prepare for update
|
||||||
self.optim.zero_grad(set_to_none=True)
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
|
||||||
# Latent rollout
|
# Latent rollout
|
||||||
@@ -241,25 +302,26 @@ class TDMPC2:
|
|||||||
z = self.model.encode(obs[0], task)
|
z = self.model.encode(obs[0], task)
|
||||||
zs[0] = z
|
zs[0] = z
|
||||||
consistency_loss = 0
|
consistency_loss = 0
|
||||||
for t in range(self.cfg.horizon):
|
for t, (_action, _next_z) in enumerate(zip(action.unbind(0), next_z.unbind(0))):
|
||||||
z = self.model.next(z, action[t], task)
|
z = self.model.next(z, _action, task)
|
||||||
consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t
|
consistency_loss = consistency_loss + F.mse_loss(z, _next_z) * self.cfg.rho**t
|
||||||
zs[t+1] = z
|
zs[t+1] = z
|
||||||
|
|
||||||
# Predictions
|
# Predictions
|
||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, value_loss = 0, 0
|
reward_loss, value_loss = 0, 0
|
||||||
for t in range(self.cfg.horizon):
|
for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
|
||||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
|
||||||
for q in range(self.cfg.num_q):
|
for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
|
||||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
|
||||||
consistency_loss *= (1/self.cfg.horizon)
|
|
||||||
reward_loss *= (1/self.cfg.horizon)
|
consistency_loss = consistency_loss / self.cfg.horizon
|
||||||
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
|
reward_loss = reward_loss / self.cfg.horizon
|
||||||
|
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
self.cfg.reward_coef * reward_loss +
|
self.cfg.reward_coef * reward_loss +
|
||||||
@@ -270,21 +332,23 @@ class TDMPC2:
|
|||||||
total_loss.backward()
|
total_loss.backward()
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
|
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
|
||||||
self.optim.step()
|
self.optim.step()
|
||||||
|
self.optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Update policy
|
# Update policy
|
||||||
pi_loss = self.update_pi(zs.detach(), task)
|
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
|
||||||
|
|
||||||
# Update target Q-functions
|
# Update target Q-functions
|
||||||
self.model.soft_update_target_Q()
|
self.model.soft_update_target_Q()
|
||||||
|
|
||||||
# Return training statistics
|
# Return training statistics
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
return {
|
return TensorDict({
|
||||||
"consistency_loss": float(consistency_loss.mean().item()),
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": float(reward_loss.mean().item()),
|
"reward_loss": reward_loss,
|
||||||
"value_loss": float(value_loss.mean().item()),
|
"value_loss": value_loss,
|
||||||
"pi_loss": pi_loss,
|
"pi_loss": pi_loss,
|
||||||
"total_loss": float(total_loss.mean().item()),
|
"total_loss": total_loss,
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": grad_norm,
|
||||||
"pi_scale": float(self.scale.value),
|
"pi_grad_norm": pi_grad_norm,
|
||||||
}
|
"pi_scale": self.scale.value,
|
||||||
|
}).detach().mean()
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
os.environ['MUJOCO_GL'] = 'egl'
|
os.environ['MUJOCO_GL'] = 'egl'
|
||||||
os.environ['LAZY_LEGACY_OP'] = '0'
|
os.environ['LAZY_LEGACY_OP'] = '0'
|
||||||
|
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
|
||||||
|
os.environ['TORCH_LOGS'] = "+recompiles"
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
import torch
|
import torch
|
||||||
@@ -16,9 +18,42 @@ from tdmpc2 import TDMPC2
|
|||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
from trainer.online_trainer import OnlineTrainer
|
from trainer.online_trainer import OnlineTrainer
|
||||||
from common.logger import Logger
|
from common.logger import Logger
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any
|
||||||
|
from omegaconf import OmegaConf
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
|
def cfg_to_dataclass(cfg, frozen=False):
|
||||||
|
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||||
|
cfg_dict = OmegaConf.to_container(cfg)
|
||||||
|
fields = []
|
||||||
|
for key, value in cfg_dict.items():
|
||||||
|
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||||
|
|
||||||
|
# Create the dataclass
|
||||||
|
dataclass_name = "Config"
|
||||||
|
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||||
|
def get(self, val, default=None):
|
||||||
|
return getattr(self, val, default)
|
||||||
|
dataclass.get = get
|
||||||
|
return dataclass()
|
||||||
|
|
||||||
|
def cfg_to_dataclass(cfg, frozen=False):
|
||||||
|
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||||
|
cfg_dict = OmegaConf.to_container(cfg)
|
||||||
|
fields = []
|
||||||
|
for key, value in cfg_dict.items():
|
||||||
|
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||||
|
|
||||||
|
# Create the dataclass
|
||||||
|
dataclass_name = "Config"
|
||||||
|
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||||
|
def get(self, val, default=None):
|
||||||
|
return getattr(self, val, default)
|
||||||
|
dataclass.get = get
|
||||||
|
return dataclass()
|
||||||
|
|
||||||
@hydra.main(config_name='config', config_path='.')
|
@hydra.main(config_name='config', config_path='.')
|
||||||
def train(cfg: dict):
|
def train(cfg: dict):
|
||||||
@@ -47,6 +82,9 @@ def train(cfg: dict):
|
|||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||||
|
|
||||||
|
cfg = cfg_to_dataclass(cfg)
|
||||||
|
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from time import time
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict.tensordict import TensorDict
|
from tensordict.tensordict import TensorDict
|
||||||
|
from torchrl._utils import timeit
|
||||||
from trainer.base import Trainer
|
from trainer.base import Trainer
|
||||||
|
|
||||||
|
|
||||||
@@ -57,61 +57,64 @@ class OnlineTrainer(Trainer):
|
|||||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
td = TensorDict(dict(
|
td = TensorDict(
|
||||||
obs=obs,
|
obs=obs,
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
), batch_size=(1,))
|
batch_size=(1,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Train a TD-MPC2 agent."""
|
||||||
train_metrics, done, eval_next = {}, True, True
|
train_metrics, done, eval_next = {}, True, False
|
||||||
while self._step <= self.cfg.steps:
|
while self._step <= self.cfg.steps:
|
||||||
|
with timeit("global-step"):
|
||||||
|
# Evaluate agent periodically
|
||||||
|
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
||||||
|
eval_next = True
|
||||||
|
|
||||||
# Evaluate agent periodically
|
# Reset environment
|
||||||
if self._step % self.cfg.eval_freq == 0:
|
if done or (self._step == self.cfg.seed_steps + 1):
|
||||||
eval_next = True
|
if eval_next:
|
||||||
|
eval_metrics = self.eval()
|
||||||
|
eval_metrics.update(self.common_metrics())
|
||||||
|
self.logger.log(eval_metrics, 'eval')
|
||||||
|
eval_next = False
|
||||||
|
|
||||||
# Reset environment
|
if self._step > 0:
|
||||||
if done:
|
train_metrics.update(
|
||||||
if eval_next:
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
eval_metrics = self.eval()
|
episode_success=info['success'],
|
||||||
eval_metrics.update(self.common_metrics())
|
)
|
||||||
self.logger.log(eval_metrics, 'eval')
|
train_metrics.update(self.common_metrics())
|
||||||
eval_next = False
|
train_metrics.update(timeit.todict())
|
||||||
|
self.logger.log(train_metrics, 'train')
|
||||||
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
|
|
||||||
if self._step > 0:
|
obs = self.env.reset()
|
||||||
train_metrics.update(
|
self._tds = [self.to_td(obs)]
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
|
||||||
episode_success=info['success'],
|
|
||||||
)
|
|
||||||
train_metrics.update(self.common_metrics())
|
|
||||||
self.logger.log(train_metrics, 'train')
|
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
|
||||||
|
|
||||||
obs = self.env.reset()
|
# Collect experience
|
||||||
self._tds = [self.to_td(obs)]
|
with timeit("act"):
|
||||||
|
if self._step > self.cfg.seed_steps:
|
||||||
|
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||||
|
else:
|
||||||
|
action = self.env.rand_act()
|
||||||
|
obs, reward, done, info = self.env.step(action)
|
||||||
|
self._tds.append(self.to_td(obs, action, reward))
|
||||||
|
|
||||||
# Collect experience
|
# Update agent
|
||||||
if self._step > self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
if self._step == self.cfg.seed_steps:
|
||||||
else:
|
num_updates = self.cfg.seed_steps
|
||||||
action = self.env.rand_act()
|
print('Pretraining agent on seed data...')
|
||||||
obs, reward, done, info = self.env.step(action)
|
else:
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
num_updates = 1
|
||||||
|
for _ in range(num_updates):
|
||||||
|
with timeit("update"):
|
||||||
|
_train_metrics = self.agent.update(self.buffer)
|
||||||
|
train_metrics.update(_train_metrics)
|
||||||
|
|
||||||
# Update agent
|
self._step += 1
|
||||||
if self._step >= self.cfg.seed_steps:
|
|
||||||
if self._step == self.cfg.seed_steps:
|
|
||||||
num_updates = self.cfg.seed_steps
|
|
||||||
print('Pretraining agent on seed data...')
|
|
||||||
else:
|
|
||||||
num_updates = 1
|
|
||||||
for _ in range(num_updates):
|
|
||||||
_train_metrics = self.agent.update(self.buffer)
|
|
||||||
train_metrics.update(_train_metrics)
|
|
||||||
|
|
||||||
self._step += 1
|
|
||||||
|
|
||||||
self.logger.finish(self.agent)
|
self.logger.finish(self.agent)
|
||||||
|
|||||||
Reference in New Issue
Block a user