merged commits

This commit is contained in:
vmoens
2024-09-25 07:57:26 -07:00
parent 88095e7899
commit 8b731819a6
12 changed files with 406 additions and 271 deletions

View File

@@ -1,56 +1,46 @@
name: tdmpc2 name: graph
channels: channels:
- pytorch-nightly - pytorch-nightly
- nvidia - nvidia
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
- cudatoolkit=11.7 - glew=2.2.0
- glew=2.1.0 - glib=2.78.4
- glib=2.68.4 - pip=24.0
- pip=21.0 - python=3.9
- python=3.9.0 - pytorch
- pytorch>=2.2.2 - pytorch-cuda=12.4
- torchvision>=0.16.2 - torchvision
- pip: - pip:
- absl-py==2.0.0 - absl-py==2.1.0
- "cython<3"
- dm-control==1.0.8 - dm-control==1.0.8
- glfw==2.7.0
- ffmpeg==1.4 - ffmpeg==1.4
- glfw==2.6.4 - imageio==2.34.1
- imageio-ffmpeg==0.4.9
- h5py==3.11.0
- hydra-core==1.3.2 - hydra-core==1.3.2
- hydra-submitit-launcher==1.2.0 - hydra-submitit-launcher==1.2.0
- imageio==2.33.1 - submitit==1.5.1
- imageio-ffmpeg==0.4.9 - omegaconf==2.3.0
- kornia==0.7.1
- moviepy==1.0.3 - moviepy==1.0.3
- mujoco==2.3.1 - mujoco==2.3.1
- mujoco-py==2.1.2.14 - numpy==1.24.4
- numpy==1.23.5 - tensordict-nightly
- omegaconf==2.3.0 - torchrl-nightly
- open3d==0.18.0 - kornia==0.7.2
- opencv-contrib-python==4.9.0.80
- opencv-python==4.9.0.80
- pandas==2.1.4
- sapien==2.2.1
- submitit==1.5.1
- setuptools==65.5.0
- patchelf==0.17.2.1
- protobuf==4.25.2
- pillow==10.2.0
- pyquaternion==0.9.9
- tensordict-nightly==2024.3.26
- termcolor==2.4.0 - termcolor==2.4.0
- torchrl-nightly==2024.3.26 - tqdm==4.66.4
- transforms3d==0.4.1 - pandas==2.0.3
- trimesh==4.0.9 - wandb==0.17.4
- tqdm==4.66.1 - matplotlib==3.7.5
- wandb==0.16.2 - seaborn==0.13.2
- wheel==0.38.0 - gpustat==1.1.1
#################### ####################
# Gym: # Gym:
# (unmaintained but required for maniskill2/meta-world/myosuite) # (unmaintained but required for maniskill2/meta-world/myosuite)
# - gym==0.21.0 - gym==0.21.0
#################### ####################
# ManiSkill2: # ManiSkill2:
# (requires gym==0.21.0 which occasionally breaks) # (requires gym==0.21.0 which occasionally breaks)

34
requirements.txt Normal file
View File

@@ -0,0 +1,34 @@
absl-py
cython
dm-control
ffmpeg
glfw
hydra-core
hydra-submitit-launcher
imageio
imageio-ffmpeg
kornia
moviepy
mujoco
mujoco-py
numpy<2
omegaconf
open3d
opencv-contrib-python
opencv-python
pandas
sapien
submitit
setuptools
patchelf
protobuf
pillow
pyquaternion
tensordict-nightly
termcolor
torchrl-nightly
transforms3d
trimesh
tqdm
wandb
wheel

View File

@@ -12,7 +12,7 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda') self._device = torch.device('cuda:0')
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler( self._sampler = SliceSampler(
num_slices=self.cfg.batch_size, num_slices=self.cfg.batch_size,
@@ -28,7 +28,7 @@ class Buffer():
def capacity(self): def capacity(self):
"""Return the capacity of the buffer.""" """Return the capacity of the buffer."""
return self._capacity return self._capacity
@property @property
def num_eps(self): def num_eps(self):
"""Return the number of episodes in the buffer.""" """Return the number of episodes in the buffer."""
@@ -41,8 +41,8 @@ class Buffer():
return ReplayBuffer( return ReplayBuffer(
storage=storage, storage=storage,
sampler=self._sampler, sampler=self._sampler,
pin_memory=True, pin_memory=False,
prefetch=1, prefetch=0,
batch_size=self._batch_size, batch_size=self._batch_size,
) )
@@ -58,32 +58,30 @@ class Buffer():
total_bytes = bytes_per_step*self._capacity total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.') print(f'Using {storage_device.upper()} memory for storage.')
self._storage_device = torch.device(storage_device)
return self._reserve_buffer( return self._reserve_buffer(
LazyTensorStorage(self._capacity, device=torch.device(storage_device)) LazyTensorStorage(self._capacity, device=self._storage_device)
) )
def _to_device(self, *args, device=None):
if device is None:
device = self._device
return (arg.to(device, non_blocking=True) \
if arg is not None else None for arg in args)
def _prepare_batch(self, td): def _prepare_batch(self, td):
""" """
Prepare a sampled batch for training (post-processing). Prepare a sampled batch for training (post-processing).
Expects `td` to be a TensorDict with batch size TxB. Expects `td` to be a TensorDict with batch size TxB.
""" """
obs = td['obs'] td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True)
action = td['action'][1:] obs = td.get('obs').contiguous()
reward = td['reward'][1:].unsqueeze(-1) action = td.get('action')[1:].contiguous()
task = td['task'][0] if 'task' in td.keys() else None reward = td.get('reward')[1:].unsqueeze(-1).contiguous()
return self._to_device(obs, action, reward, task) task = td.get('task', None)
if task is not None:
task = task[0].contiguous()
return obs, action, reward, task
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""
td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64)
if self._num_eps == 0: if self._num_eps == 0:
self._buffer = self._init(td) self._buffer = self._init(td)
self._buffer.extend(td) self._buffer.extend(td)

View File

@@ -1,8 +1,8 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from functorch import combine_state_for_ensemble from tensordict import from_modules
from copy import deepcopy
class Ensemble(nn.Module): class Ensemble(nn.Module):
""" """
@@ -11,14 +11,18 @@ class Ensemble(nn.Module):
def __init__(self, modules, **kwargs): def __init__(self, modules, **kwargs):
super().__init__() super().__init__()
modules = nn.ModuleList(modules) # combine_state_for_ensemble causes graph breaks
fn, params, _ = combine_state_for_ensemble(modules) self.params = from_modules(*modules, as_module=True)
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs) with self.params[0].data.to("meta").to_module(modules[0]):
self.params = nn.ParameterList([nn.Parameter(p) for p in params]) self.module = deepcopy(modules[0])
self._repr = str(modules) self._repr = str(modules)
def _call(self, params, *args, **kwargs):
with params.to_module(self.module):
return self.module(*args, **kwargs)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
return self.vmap([p for p in self.params], (), *args, **kwargs) return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
def __repr__(self): def __repr__(self):
return 'Vectorized ' + self._repr return 'Vectorized ' + self._repr
@@ -32,13 +36,13 @@ class ShiftAug(nn.Module):
def __init__(self, pad=3): def __init__(self, pad=3):
super().__init__() super().__init__()
self.pad = pad self.pad = pad
self.padding = tuple([self.pad] * 4)
def forward(self, x): def forward(self, x):
x = x.float() x = x.float()
n, _, h, w = x.size() n, _, h, w = x.size()
assert h == w assert h == w
padding = tuple([self.pad] * 4) x = F.pad(x, self.padding, 'replicate')
x = F.pad(x, padding, 'replicate')
eps = 1.0 / (h + 2 * self.pad) eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h] arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
@@ -59,7 +63,7 @@ class PixelPreprocess(nn.Module):
super().__init__() super().__init__()
def forward(self, x): def forward(self, x):
return x.div_(255.).sub_(0.5) return x.div(255.).sub(0.5)
class SimNorm(nn.Module): class SimNorm(nn.Module):
@@ -67,17 +71,17 @@ class SimNorm(nn.Module):
Simplicial normalization. Simplicial normalization.
Adapted from https://arxiv.org/abs/2204.00616. Adapted from https://arxiv.org/abs/2204.00616.
""" """
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.dim = cfg.simnorm_dim self.dim = cfg.simnorm_dim
def forward(self, x): def forward(self, x):
shp = x.shape shp = x.shape
x = x.view(*shp[:-1], -1, self.dim) x = x.view(*shp[:-1], -1, self.dim)
x = F.softmax(x, dim=-1) x = F.softmax(x, dim=-1)
return x.view(*shp) return x.view(*shp)
def __repr__(self): def __repr__(self):
return f"SimNorm(dim={self.dim})" return f"SimNorm(dim={self.dim})"
@@ -87,18 +91,20 @@ class NormedLinear(nn.Linear):
Linear layer with LayerNorm, activation, and optionally dropout. Linear layer with LayerNorm, activation, and optionally dropout.
""" """
def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs): def __init__(self, *args, dropout=0., act=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.ln = nn.LayerNorm(self.out_features) self.ln = nn.LayerNorm(self.out_features)
if act is None:
act = nn.Mish(inplace=False)
self.act = act self.act = act
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None
def forward(self, x): def forward(self, x):
x = super().forward(x) x = super().forward(x)
if self.dropout: if self.dropout:
x = self.dropout(x) x = self.dropout(x)
return self.act(self.ln(x)) return self.act(self.ln(x))
def __repr__(self): def __repr__(self):
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else "" repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
return f"NormedLinear(in_features={self.in_features}, "\ return f"NormedLinear(in_features={self.in_features}, "\
@@ -130,9 +136,9 @@ def conv(in_shape, num_channels, act=None):
assert in_shape[-1] == 64 # assumes rgb observations to be 64x64 assert in_shape[-1] == 64 # assumes rgb observations to be 64x64
layers = [ layers = [
ShiftAug(), PixelPreprocess(), ShiftAug(), PixelPreprocess(),
nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True), nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True), nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True), nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=False),
nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()] nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()]
if act: if act:
layers.append(act) layers.append(act)

View File

@@ -1,11 +1,11 @@
import dataclasses
import os import os
import datetime import datetime
import re import re
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from termcolor import colored from termcolor import colored
from omegaconf import OmegaConf from torchrl._utils import timeit
from common import TASK_SET from common import TASK_SET
@@ -133,7 +133,7 @@ class Logger:
group=self._group, group=self._group,
tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
dir=self._log_dir, dir=self._log_dir,
config=OmegaConf.to_container(cfg, resolve=True), config=dataclasses.asdict(cfg),
) )
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb self._wandb = wandb
@@ -238,3 +238,5 @@ class Logger:
self._log_dir / "eval.csv", header=keys, index=None self._log_dir / "eval.csv", header=keys, index=None
) )
self._print(d, category) self._print(d, category)
timeit.print()
timeit.erase()

View File

@@ -9,30 +9,30 @@ def soft_ce(pred, target, cfg):
return -(target * pred).sum(-1, keepdim=True) return -(target * pred).sum(-1, keepdim=True)
@torch.jit.script
def log_std(x, low, dif): def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1) return low + 0.5 * dif * (torch.tanh(x) + 1)
@torch.jit.script
def _gaussian_residual(eps, log_std): def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std return -0.5 * eps.pow(2) - log_std
@torch.jit.script
def _gaussian_logprob(residual): def _gaussian_logprob(residual):
return residual - 0.5 * torch.log(2 * torch.pi) log2pi = 1.8378770351409912
return residual - 0.5 * log2pi
def gaussian_logprob(eps, log_std, size=None): def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability.""" """Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
if size is None: if size is None:
size = eps.size(-1) size = eps.shape[-1]
return _gaussian_logprob(residual) * size return _gaussian_logprob(residual) * size
@torch.jit.script
def _squash(pi): def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
@@ -45,7 +45,7 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi return mu, pi, log_pi
@torch.jit.script
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.
@@ -54,7 +54,7 @@ def symlog(x):
return torch.sign(x) * torch.log(1 + torch.abs(x)) return torch.sign(x) * torch.log(1 + torch.abs(x))
@torch.jit.script
def symexp(x): def symexp(x):
""" """
Symmetric exponential function. Symmetric exponential function.
@@ -70,26 +70,32 @@ def two_hot(x, cfg):
elif cfg.num_bins == 1: elif cfg.num_bins == 1:
return symlog(x) return symlog(x)
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1) x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1)
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size)
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx).unsqueeze(-1)
soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) soft_two_hot = torch.zeros(x.shape[0], cfg.num_bins, device=x.device, dtype=x.dtype)
soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset) bin_idx = bin_idx.long()
soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) soft_two_hot = soft_two_hot.scatter(1, bin_idx.unsqueeze(1), 1 - bin_offset)
soft_two_hot = soft_two_hot.scatter(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset)
return soft_two_hot return soft_two_hot
DREG_BINS = None
def two_hot_inv(x, cfg): def two_hot_inv(x, cfg):
"""Converts a batch of soft two-hot encoded vectors to scalars.""" """Converts a batch of soft two-hot encoded vectors to scalars."""
global DREG_BINS
if cfg.num_bins == 0: if cfg.num_bins == 0:
return x return x
elif cfg.num_bins == 1: elif cfg.num_bins == 1:
return symexp(x) return symexp(x)
if DREG_BINS is None: dreg_bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device, dtype=x.dtype)
DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device)
x = F.softmax(x, dim=-1) x = F.softmax(x, dim=-1)
x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True) x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
return symexp(x) return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
logits = p.log()
# Generate Gumbel noise
gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
) # ~Gumbel(0,1)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1)

View File

@@ -1,48 +1,49 @@
import torch import torch
from torch.nn import Buffer
class RunningScale(torch.nn.Module):
class RunningScale:
"""Running trimmed scale estimator.""" """Running trimmed scale estimator."""
def __init__(self, cfg): def __init__(self, cfg):
super().__init__()
self.cfg = cfg self.cfg = cfg
self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
def state_dict(self): def state_dict(self):
return dict(value=self._value, percentiles=self._percentiles) return dict(value=self.value, percentiles=self._percentiles)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self._value.data.copy_(state_dict['value']) self.value.copy_(state_dict['value'])
self._percentiles.data.copy_(state_dict['percentiles']) self._percentiles.copy_(state_dict['percentiles'])
@property def _positions(self, x_shape):
def value(self): positions = self._percentiles * (x_shape-1) / 100
return self._value.cpu().item() floored = torch.floor(positions)
ceiled = floored + 1
ceiled = torch.where(ceiled > x_shape - 1, x_shape - 1, ceiled)
weight_ceiled = positions-floored
weight_floored = 1.0 - weight_ceiled
return floored.long(), ceiled.long(), weight_floored.unsqueeze(1), weight_ceiled.unsqueeze(1)
def _percentile(self, x): def _percentile(self, x):
x_dtype, x_shape = x.dtype, x.shape x_dtype, x_shape = x.dtype, x.shape
x = x.view(x.shape[0], -1) x = x.flatten(1, x.ndim-1)
in_sorted, _ = torch.sort(x, dim=0) in_sorted = torch.sort(x, dim=0).values
positions = self._percentiles * (x.shape[0]-1) / 100 floored, ceiled, weight_floored, weight_ceiled = self._positions(x.shape[0])
floored = torch.floor(positions) d0 = in_sorted[floored] * weight_floored
ceiled = floored + 1 d1 = in_sorted[ceiled] * weight_ceiled
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1 return (d0+d1).reshape(-1, *x_shape[1:]).to(x_dtype)
weight_ceiled = positions-floored
weight_floored = 1.0 - weight_ceiled
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype)
def update(self, x): def update(self, x):
percentiles = self._percentile(x.detach()) percentiles = self._percentile(x.detach())
value = torch.clamp(percentiles[1] - percentiles[0], min=1.) value = torch.clamp(percentiles[1] - percentiles[0], min=1.)
self._value.data.lerp_(value, self.cfg.tau) self.value.data.lerp_(value, self.cfg.tau)
def __call__(self, x, update=False): def forward(self, x, update=False):
if update: if update:
self.update(x) self.update(x)
return x * (1/self.value) return x / self.value
def __repr__(self): def __repr__(self):
return f'RunningScale(S: {self.value})' return f'RunningScale(S: {self.value})'

View File

@@ -5,7 +5,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from common import layers, math, init from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
@@ -18,7 +19,7 @@ class WorldModel(nn.Module):
self.cfg = cfg self.cfg = cfg
if cfg.multitask: if cfg.multitask:
self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1)
self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) self.register_buffer("_action_masks", torch.zeros(len(cfg.tasks), cfg.action_dim))
for i in range(len(cfg.tasks)): for i in range(len(cfg.tasks)):
self._action_masks[i, :cfg.action_dims[i]] = 1. self._action_masks[i, :cfg.action_dims[i]] = 1.
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
@@ -27,26 +28,35 @@ class WorldModel(nn.Module):
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
self.log_std_min = torch.tensor(cfg.log_std_min) self.register_buffer("log_std_min", torch.tensor(cfg.log_std_min))
self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min self.register_buffer("log_std_dif", torch.tensor(cfg.log_std_max) - self.log_std_min)
self.init()
def init(self):
# Create params
self._detach_Qs_params = TensorDictParams(self._Qs.params.data, no_convert=True)
self._target_Qs_params = TensorDictParams(self._Qs.params.data.clone(), no_convert=True)
# Create modules
with self._detach_Qs_params.data.to("meta").to_module(self._Qs.module):
self._detach_Qs = deepcopy(self._Qs)
self._target_Qs = deepcopy(self._Qs)
# Assign params to modules
self._detach_Qs.params = self._detach_Qs_params
self._target_Qs.params = self._target_Qs_params
@property @property
def total_params(self): def total_params(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad) return sum(p.numel() for p in self.parameters() if p.requires_grad)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
"""
Overriding `to` method to also move additional tensors to device.
"""
super().to(*args, **kwargs) super().to(*args, **kwargs)
if self.cfg.multitask: self.init()
self._action_masks = self._action_masks.to(*args, **kwargs)
self.log_std_min = self.log_std_min.to(*args, **kwargs)
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
return self return self
def train(self, mode=True): def train(self, mode=True):
""" """
Overriding `train` method to keep target Q-networks in eval mode. Overriding `train` method to keep target Q-networks in eval mode.
@@ -55,26 +65,12 @@ class WorldModel(nn.Module):
self._target_Qs.train(False) self._target_Qs.train(False)
return self return self
def track_q_grad(self, mode=True):
"""
Enables/disables gradient tracking of Q-networks.
Avoids unnecessary computation during policy optimization.
This method also enables/disables gradients for task embeddings.
"""
for p in self._Qs.parameters():
p.requires_grad_(mode)
if self.cfg.multitask:
for p in self._task_emb.parameters():
p.requires_grad_(mode)
def soft_update_target_Q(self): def soft_update_target_Q(self):
""" """
Soft-update target Q-networks using Polyak averaging. Soft-update target Q-networks using Polyak averaging.
""" """
with torch.no_grad(): self._target_Qs_params.lerp_(self._detach_Qs_params, self.cfg.tau)
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()):
p_target.data.lerp_(p.data, self.cfg.tau)
def task_emb(self, x, task): def task_emb(self, x, task):
""" """
Continuous task embedding for multi-task experiments. Continuous task embedding for multi-task experiments.
@@ -109,7 +105,7 @@ class WorldModel(nn.Module):
z = self.task_emb(z, task) z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._dynamics(z) return self._dynamics(z)
def reward(self, z, a, task): def reward(self, z, a, task):
""" """
Predicts instantaneous (single-step) reward. Predicts instantaneous (single-step) reward.
@@ -147,7 +143,7 @@ class WorldModel(nn.Module):
return mu, pi, log_pi, log_std return mu, pi, log_pi, log_std
def Q(self, z, a, task, return_type='min', target=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """
Predict state-action value. Predict state-action value.
`return_type` can be one of [`min`, `avg`, `all`]: `return_type` can be one of [`min`, `avg`, `all`]:
@@ -160,13 +156,21 @@ class WorldModel(nn.Module):
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
out = (self._target_Qs if target else self._Qs)(z) if target:
qnet = self._target_Qs
elif detach:
qnet = self._detach_Qs
else:
qnet = self._Qs
out = qnet(z)
if return_type == 'all': if return_type == 'all':
return out return out
Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)] qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg) Q = math.two_hot_inv(out[qidx], self.cfg)
return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2 if return_type == "min":
return Q.min(0).values
return Q.sum(0) / 2

View File

@@ -86,3 +86,7 @@ action_dims: ???
episode_lengths: ??? episode_lengths: ???
seed_steps: ??? seed_steps: ???
bin_size: ??? bin_size: ???
# compile
compile: False
cudagraphs: False

View File

@@ -1,13 +1,18 @@
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import functools
from torchrl._utils import timeit
from common import math from common import math
from common.scale import RunningScale from common.scale import RunningScale
from common.world_model import WorldModel from common.world_model import WorldModel
from tensordict.nn import CudaGraphModule
from tensordict import TensorDict
CG_WARMUP = 1000
class TDMPC2: class TDMPC2(torch.nn.Module):
""" """
TD-MPC2 agent. Implements training + inference. TD-MPC2 agent. Implements training + inference.
Can be used for both single-task and multi-task experiments, Can be used for both single-task and multi-task experiments,
@@ -15,23 +20,71 @@ class TDMPC2:
""" """
def __init__(self, cfg): def __init__(self, cfg):
super().__init__()
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda')
self.device = torch.device('cuda:0')
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg).to(self.device)
capturable = True
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()}, {'params': self.model._reward.parameters()},
{'params': self.model._Qs.parameters()}, {'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []} {'params': self.model._task_emb.parameters() if self.cfg.multitask else []
], lr=self.cfg.lr) }
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5) ], lr=self.cfg.lr, capturable=capturable)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable)
self.model.eval() self.model.eval()
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
mode = None if cfg.cudagraphs else "reduce-overhead"
print('compiling - update')
self._update = torch.compile(self._update, mode=mode)
if cfg.cudagraphs:
print('cudagraphs - update')
self._update = CudaGraphModule(self._update, warmup=CG_WARMUP)
@property
def plan(self):
_plan_val = getattr(self, "_plan_val", None)
if _plan_val is not None:
return _plan_val
if self.cfg.cudagraphs:
print('cudagraphs - plan')
self._plan_dict = {
(True, True): functools.partial(self._plan, t0=True, eval_mode=True),
(False, True): functools.partial(self._plan, t0=False, eval_mode=True),
(True, False): functools.partial(self._plan, t0=True, eval_mode=False),
(False, False): functools.partial(self._plan, t0=False, eval_mode=False),
}
if self.cfg.compile:
print('compiling - plan')
mode = None
self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()}
self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()}
def plan(obs, t0=False, eval_mode=False, task=None):
if task is not None:
kwargs = {"task": task}
else:
kwargs = {}
torch.compiler.cudagraph_mark_step_begin()
return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs)
elif self.cfg.compile:
plan = torch.compile(self._plan, mode="reduce-overhead")
else:
plan = self._plan
self._plan_val = plan
return self._plan_val
def _get_discount(self, episode_length): def _get_discount(self, episode_length):
""" """
@@ -51,7 +104,7 @@ class TDMPC2:
def save(self, fp): def save(self, fp):
""" """
Save state dict of the agent to filepath. Save state dict of the agent to filepath.
Args: Args:
fp (str): Filepath to save state dict to. fp (str): Filepath to save state dict to.
""" """
@@ -60,7 +113,7 @@ class TDMPC2:
def load(self, fp): def load(self, fp):
""" """
Load a saved state dict from filepath (or dictionary) into current agent. Load a saved state dict from filepath (or dictionary) into current agent.
Args: Args:
fp (str or dict): Filepath or state dict to load. fp (str or dict): Filepath or state dict to load.
""" """
@@ -71,23 +124,23 @@ class TDMPC2:
def act(self, obs, t0=False, eval_mode=False, task=None): def act(self, obs, t0=False, eval_mode=False, task=None):
""" """
Select an action by planning in the latent space of the world model. Select an action by planning in the latent space of the world model.
Args: Args:
obs (torch.Tensor): Observation from the environment. obs (torch.Tensor): Observation from the environment.
t0 (bool): Whether this is the first observation in the episode. t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution. eval_mode (bool): Whether to use the mean of the action distribution.
task (int): Task index (only used for multi-task experiments). task (int): Task index (only used for multi-task experiments).
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
obs = obs.to(self.device, non_blocking=True).unsqueeze(0) obs = obs.to(self.device, non_blocking=True).unsqueeze(0)
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
z = self.model.encode(obs, task)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
else: else:
z = self.model.encode(obs, task)
a = self.model.pi(z, task)[int(not eval_mode)][0] a = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu() return a.cpu()
@@ -98,15 +151,16 @@ class TDMPC2:
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[t], task)
G += discount * reward G = G + discount * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
""" """
Plan a sequence of actions using the learned world model. Plan a sequence of actions using the learned world model.
Args: Args:
z (torch.Tensor): Latent state from which to plan. z (torch.Tensor): Latent state from which to plan.
t0 (bool): Whether this is the first observation in the episode. t0 (bool): Whether this is the first observation in the episode.
@@ -115,8 +169,9 @@ class TDMPC2:
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
# Sample policy trajectories # Sample policy trajectories
z = self.model.encode(obs, task)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.repeat(self.cfg.num_pi_trajs, 1)
@@ -128,52 +183,53 @@ class TDMPC2:
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.repeat(self.cfg.num_samples, 1)
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0: if not t0:
mean[:-1] = self._prev_mean[1:] mean[:-1] = self._prev_mean[1:]
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Iterate MPPI # Iterate MPPI
for _ in range(self.cfg.iterations): for _ in range(self.cfg.iterations):
# Sample actions # Sample actions
actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
.clamp(-1, 1) actions_sample = actions_sample.clamp(-1, 1)
actions[:, self.cfg.num_pi_trajs:] = actions_sample
if self.cfg.multitask: if self.cfg.multitask:
actions = actions * self.model._action_masks[task] actions = actions * self.model._action_masks[task]
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Update parameters # Update parameters
max_value = elite_value.max(0)[0] max_value = elite_value.max(0).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score /= score.sum(0) score = score / score.sum(0)
mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
.clamp_(self.cfg.min_std, self.cfg.max_std) std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action
score = score.squeeze(1).cpu().numpy() rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
self._prev_mean = mean
a, std = actions[0], std[0] a, std = actions[0], std[0]
if not eval_mode: if not eval_mode:
a += std * torch.randn(self.cfg.action_dim, device=std.device) a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
return a.clamp_(-1, 1) self._prev_mean.copy_(mean)
return a.clamp(-1, 1)
def update_pi(self, zs, task): def update_pi(self, zs, task):
""" """
Update policy using a sequence of latent states. Update policy using a sequence of latent states.
Args: Args:
zs (torch.Tensor): Sequence of latent states. zs (torch.Tensor): Sequence of latent states.
task (torch.Tensor): Task index (only used for multi-task experiments). task (torch.Tensor): Task index (only used for multi-task experiments).
@@ -181,10 +237,8 @@ class TDMPC2:
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
self.pi_optim.zero_grad(set_to_none=True)
self.model.track_q_grad(False)
_, pis, log_pis, _ = self.model.pi(zs, task) _, pis, log_pis, _ = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg') qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
self.scale.update(qs[0]) self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
@@ -192,22 +246,23 @@ class TDMPC2:
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
self.model.track_q_grad(True) # For some reason, cudagraph prefers to see the zero grad after step
self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.item() return pi_loss.detach(), pi_grad_norm
@torch.no_grad() @torch.no_grad()
def _td_target(self, next_z, reward, task): def _td_target(self, next_z, reward, task):
""" """
Compute the TD-target from a reward and the observation at the following time step. Compute the TD-target from a reward and the observation at the following time step.
Args: Args:
next_z (torch.Tensor): Latent state at the following time step. next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step. reward (torch.Tensor): Reward at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments). task (torch.Tensor): Task index (only used for multi-task experiments).
Returns: Returns:
torch.Tensor: TD-target. torch.Tensor: TD-target.
""" """
@@ -218,22 +273,28 @@ class TDMPC2:
def update(self, buffer): def update(self, buffer):
""" """
Main update function. Corresponds to one iteration of model learning. Main update function. Corresponds to one iteration of model learning.
Args: Args:
buffer (common.buffer.Buffer): Replay buffer. buffer (common.buffer.Buffer): Replay buffer.
Returns: Returns:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
obs, action, reward, task = buffer.sample() with timeit("sample"):
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs)
def _update(self, obs, action, reward, task=None):
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
next_z = self.model.encode(obs[1:], task) next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task) td_targets = self._td_target(next_z, reward, task)
# Prepare for update # Prepare for update
self.optim.zero_grad(set_to_none=True)
self.model.train() self.model.train()
# Latent rollout # Latent rollout
@@ -241,25 +302,26 @@ class TDMPC2:
z = self.model.encode(obs[0], task) z = self.model.encode(obs[0], task)
zs[0] = z zs[0] = z
consistency_loss = 0 consistency_loss = 0
for t in range(self.cfg.horizon): for t, (_action, _next_z) in enumerate(zip(action.unbind(0), next_z.unbind(0))):
z = self.model.next(z, action[t], task) z = self.model.next(z, _action, task)
consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t consistency_loss = consistency_loss + F.mse_loss(z, _next_z) * self.cfg.rho**t
zs[t+1] = z zs[t+1] = z
# Predictions # Predictions
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
# Compute losses # Compute losses
reward_loss, value_loss = 0, 0 reward_loss, value_loss = 0, 0
for t in range(self.cfg.horizon): for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
for q in range(self.cfg.num_q): for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon) consistency_loss = consistency_loss / self.cfg.horizon
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) reward_loss = reward_loss / self.cfg.horizon
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss + self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss + self.cfg.reward_coef * reward_loss +
@@ -270,21 +332,23 @@ class TDMPC2:
total_loss.backward() total_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm) grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm)
self.optim.step() self.optim.step()
self.optim.zero_grad(set_to_none=True)
# Update policy # Update policy
pi_loss = self.update_pi(zs.detach(), task) pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
# Update target Q-functions # Update target Q-functions
self.model.soft_update_target_Q() self.model.soft_update_target_Q()
# Return training statistics # Return training statistics
self.model.eval() self.model.eval()
return { return TensorDict({
"consistency_loss": float(consistency_loss.mean().item()), "consistency_loss": consistency_loss,
"reward_loss": float(reward_loss.mean().item()), "reward_loss": reward_loss,
"value_loss": float(value_loss.mean().item()), "value_loss": value_loss,
"pi_loss": pi_loss, "pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()), "total_loss": total_loss,
"grad_norm": float(grad_norm), "grad_norm": grad_norm,
"pi_scale": float(self.scale.value), "pi_grad_norm": pi_grad_norm,
} "pi_scale": self.scale.value,
}).detach().mean()

View File

@@ -1,6 +1,8 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = 'egl'
os.environ['LAZY_LEGACY_OP'] = '0' os.environ['LAZY_LEGACY_OP'] = '0'
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
os.environ['TORCH_LOGS'] = "+recompiles"
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
import torch import torch
@@ -16,9 +18,27 @@ from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer from trainer.online_trainer import OnlineTrainer
from common.logger import Logger from common.logger import Logger
import dataclasses
from typing import Any
from omegaconf import OmegaConf
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
def cfg_to_dataclass(cfg, frozen=False):
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
# Create the dataclass
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
@hydra.main(config_name='config', config_path='.') @hydra.main(config_name='config', config_path='.')
def train(cfg: dict): def train(cfg: dict):
@@ -47,6 +67,9 @@ def train(cfg: dict):
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
cfg = cfg_to_dataclass(cfg)
trainer = trainer_cls( trainer = trainer_cls(
cfg=cfg, cfg=cfg,
env=make_env(cfg), env=make_env(cfg),

View File

@@ -3,7 +3,7 @@ from time import time
import numpy as np import numpy as np
import torch import torch
from tensordict.tensordict import TensorDict from tensordict.tensordict import TensorDict
from torchrl._utils import timeit
from trainer.base import Trainer from trainer.base import Trainer
@@ -57,61 +57,64 @@ class OnlineTrainer(Trainer):
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
td = TensorDict(dict( td = TensorDict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
), batch_size=(1,)) batch_size=(1,))
return td return td
def train(self): def train(self):
"""Train a TD-MPC2 agent.""" """Train a TD-MPC2 agent."""
train_metrics, done, eval_next = {}, True, True train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
with timeit("global-step"):
# Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
eval_next = True
# Evaluate agent periodically # Reset environment
if self._step % self.cfg.eval_freq == 0: if done or (self._step == self.cfg.seed_steps + 1):
eval_next = True if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())
self.logger.log(eval_metrics, 'eval')
eval_next = False
# Reset environment if self._step > 0:
if done: train_metrics.update(
if eval_next: episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
eval_metrics = self.eval() episode_success=info['success'],
eval_metrics.update(self.common_metrics()) )
self.logger.log(eval_metrics, 'eval') train_metrics.update(self.common_metrics())
eval_next = False train_metrics.update(timeit.todict())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
if self._step > 0: obs = self.env.reset()
train_metrics.update( self._tds = [self.to_td(obs)]
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() # Collect experience
self._tds = [self.to_td(obs)] with timeit("act"):
if self._step > self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1)
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward))
# Collect experience # Update agent
if self._step > self.cfg.seed_steps: if self._step >= self.cfg.seed_steps:
action = self.agent.act(obs, t0=len(self._tds)==1) if self._step == self.cfg.seed_steps:
else: num_updates = self.cfg.seed_steps
action = self.env.rand_act() print('Pretraining agent on seed data...')
obs, reward, done, info = self.env.step(action) else:
self._tds.append(self.to_td(obs, action, reward)) num_updates = 1
for _ in range(num_updates):
with timeit("update"):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
# Update agent self._step += 1
if self._step >= self.cfg.seed_steps:
if self._step == self.cfg.seed_steps:
num_updates = self.cfg.seed_steps
print('Pretraining agent on seed data...')
else:
num_updates = 1
for _ in range(num_updates):
_train_metrics = self.agent.update(self.buffer)
train_metrics.update(_train_metrics)
self._step += 1
self.logger.finish(self.agent) self.logger.finish(self.agent)