From 603b67ce66135b3165f97b4e57a85cb8dd05ecc7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 25 Sep 2024 07:57:26 -0700 Subject: [PATCH 01/12] merged commits --- docker/environment.yaml | 62 ++++------ requirements.txt | 34 ++++++ tdmpc2/common/buffer.py | 34 +++--- tdmpc2/common/layers.py | 44 ++++--- tdmpc2/common/logger.py | 8 +- tdmpc2/common/math.py | 46 ++++--- tdmpc2/common/scale.py | 47 +++---- tdmpc2/common/world_model.py | 80 ++++++------ tdmpc2/config.yaml | 4 + tdmpc2/tdmpc2.py | 202 ++++++++++++++++++++----------- tdmpc2/train.py | 24 +++- tdmpc2/trainer/online_trainer.py | 89 +++++++------- 12 files changed, 406 insertions(+), 268 deletions(-) create mode 100644 requirements.txt diff --git a/docker/environment.yaml b/docker/environment.yaml index 3081b8d..857c81a 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -1,56 +1,46 @@ -name: tdmpc2 +name: graph channels: - pytorch-nightly - nvidia - conda-forge - defaults dependencies: - - cudatoolkit=11.7 - - glew=2.1.0 - - glib=2.68.4 - - pip=21.0 - - python=3.9.0 - - pytorch>=2.2.2 - - torchvision>=0.16.2 + - glew=2.2.0 + - glib=2.78.4 + - pip=24.0 + - python=3.9 + - pytorch + - pytorch-cuda=12.4 + - torchvision - pip: - - absl-py==2.0.0 - - "cython<3" + - absl-py==2.1.0 - dm-control==1.0.8 + - glfw==2.7.0 - ffmpeg==1.4 - - glfw==2.6.4 + - imageio==2.34.1 + - imageio-ffmpeg==0.4.9 + - h5py==3.11.0 - hydra-core==1.3.2 - hydra-submitit-launcher==1.2.0 - - imageio==2.33.1 - - imageio-ffmpeg==0.4.9 - - kornia==0.7.1 + - submitit==1.5.1 + - omegaconf==2.3.0 - moviepy==1.0.3 - mujoco==2.3.1 - - mujoco-py==2.1.2.14 - - numpy==1.23.5 - - omegaconf==2.3.0 - - open3d==0.18.0 - - opencv-contrib-python==4.9.0.80 - - opencv-python==4.9.0.80 - - pandas==2.1.4 - - sapien==2.2.1 - - submitit==1.5.1 - - setuptools==65.5.0 - - patchelf==0.17.2.1 - - protobuf==4.25.2 - - pillow==10.2.0 - - pyquaternion==0.9.9 - - tensordict-nightly==2024.3.26 + - numpy==1.24.4 + - tensordict-nightly + - torchrl-nightly + - kornia==0.7.2 - termcolor==2.4.0 - - torchrl-nightly==2024.3.26 - - transforms3d==0.4.1 - - trimesh==4.0.9 - - tqdm==4.66.1 - - wandb==0.16.2 - - wheel==0.38.0 + - tqdm==4.66.4 + - pandas==2.0.3 + - wandb==0.17.4 + - matplotlib==3.7.5 + - seaborn==0.13.2 + - gpustat==1.1.1 #################### # Gym: # (unmaintained but required for maniskill2/meta-world/myosuite) - # - gym==0.21.0 + - gym==0.21.0 #################### # ManiSkill2: # (requires gym==0.21.0 which occasionally breaks) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f403b3e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +absl-py +cython +dm-control +ffmpeg +glfw +hydra-core +hydra-submitit-launcher +imageio +imageio-ffmpeg +kornia +moviepy +mujoco +mujoco-py +numpy<2 +omegaconf +open3d +opencv-contrib-python +opencv-python +pandas +sapien +submitit +setuptools +patchelf +protobuf +pillow +pyquaternion +tensordict-nightly +termcolor +torchrl-nightly +transforms3d +trimesh +tqdm +wandb +wheel diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index e82f2fe..3ff5b28 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -12,7 +12,7 @@ class Buffer(): def __init__(self, cfg): self.cfg = cfg - self._device = torch.device('cuda') + self._device = torch.device('cuda:0') self._capacity = min(cfg.buffer_size, cfg.steps) self._sampler = SliceSampler( num_slices=self.cfg.batch_size, @@ -28,7 +28,7 @@ class Buffer(): def capacity(self): """Return the capacity of the buffer.""" return self._capacity - + @property def num_eps(self): """Return the number of episodes in the buffer.""" @@ -41,8 +41,8 @@ class Buffer(): return ReplayBuffer( storage=storage, sampler=self._sampler, - pin_memory=True, - prefetch=1, + pin_memory=False, + prefetch=0, batch_size=self._batch_size, ) @@ -58,32 +58,30 @@ class Buffer(): total_bytes = bytes_per_step*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory - storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' + storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu' print(f'Using {storage_device.upper()} memory for storage.') + self._storage_device = torch.device(storage_device) return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device(storage_device)) + LazyTensorStorage(self._capacity, device=self._storage_device) ) - def _to_device(self, *args, device=None): - if device is None: - device = self._device - return (arg.to(device, non_blocking=True) \ - if arg is not None else None for arg in args) - def _prepare_batch(self, td): """ Prepare a sampled batch for training (post-processing). Expects `td` to be a TensorDict with batch size TxB. """ - obs = td['obs'] - action = td['action'][1:] - reward = td['reward'][1:].unsqueeze(-1) - task = td['task'][0] if 'task' in td.keys() else None - return self._to_device(obs, action, reward, task) + td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True) + obs = td.get('obs').contiguous() + action = td.get('action')[1:].contiguous() + reward = td.get('reward')[1:].unsqueeze(-1).contiguous() + task = td.get('task', None) + if task is not None: + task = task[0].contiguous() + return obs, action, reward, task def add(self, td): """Add an episode to the buffer.""" - td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) if self._num_eps == 0: self._buffer = self._init(td) self._buffer.extend(td) diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index 1e0adb3..5890d8d 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from functorch import combine_state_for_ensemble - +from tensordict import from_modules +from copy import deepcopy class Ensemble(nn.Module): """ @@ -11,14 +11,18 @@ class Ensemble(nn.Module): def __init__(self, modules, **kwargs): super().__init__() - modules = nn.ModuleList(modules) - fn, params, _ = combine_state_for_ensemble(modules) - self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs) - self.params = nn.ParameterList([nn.Parameter(p) for p in params]) + # combine_state_for_ensemble causes graph breaks + self.params = from_modules(*modules, as_module=True) + with self.params[0].data.to("meta").to_module(modules[0]): + self.module = deepcopy(modules[0]) self._repr = str(modules) + def _call(self, params, *args, **kwargs): + with params.to_module(self.module): + return self.module(*args, **kwargs) + def forward(self, *args, **kwargs): - return self.vmap([p for p in self.params], (), *args, **kwargs) + return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) def __repr__(self): return 'Vectorized ' + self._repr @@ -32,13 +36,13 @@ class ShiftAug(nn.Module): def __init__(self, pad=3): super().__init__() self.pad = pad + self.padding = tuple([self.pad] * 4) def forward(self, x): x = x.float() n, _, h, w = x.size() assert h == w - padding = tuple([self.pad] * 4) - x = F.pad(x, padding, 'replicate') + x = F.pad(x, self.padding, 'replicate') eps = 1.0 / (h + 2 * self.pad) arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h] arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) @@ -59,7 +63,7 @@ class PixelPreprocess(nn.Module): super().__init__() def forward(self, x): - return x.div_(255.).sub_(0.5) + return x.div(255.).sub(0.5) class SimNorm(nn.Module): @@ -67,17 +71,17 @@ class SimNorm(nn.Module): Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. """ - + def __init__(self, cfg): super().__init__() self.dim = cfg.simnorm_dim - + def forward(self, x): shp = x.shape x = x.view(*shp[:-1], -1, self.dim) x = F.softmax(x, dim=-1) return x.view(*shp) - + def __repr__(self): return f"SimNorm(dim={self.dim})" @@ -87,18 +91,20 @@ class NormedLinear(nn.Linear): Linear layer with LayerNorm, activation, and optionally dropout. """ - def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs): + def __init__(self, *args, dropout=0., act=None, **kwargs): super().__init__(*args, **kwargs) self.ln = nn.LayerNorm(self.out_features) + if act is None: + act = nn.Mish(inplace=False) self.act = act - self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None + self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None def forward(self, x): x = super().forward(x) if self.dropout: x = self.dropout(x) return self.act(self.ln(x)) - + def __repr__(self): repr_dropout = f", dropout={self.dropout.p}" if self.dropout else "" return f"NormedLinear(in_features={self.in_features}, "\ @@ -130,9 +136,9 @@ def conv(in_shape, num_channels, act=None): assert in_shape[-1] == 64 # assumes rgb observations to be 64x64 layers = [ ShiftAug(), PixelPreprocess(), - nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True), + nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=False), + nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=False), + nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=False), nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()] if act: layers.append(act) diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index 4dce7ca..b3047f9 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -1,11 +1,11 @@ +import dataclasses import os import datetime import re import numpy as np import pandas as pd from termcolor import colored -from omegaconf import OmegaConf - +from torchrl._utils import timeit from common import TASK_SET @@ -133,7 +133,7 @@ class Logger: group=self._group, tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], dir=self._log_dir, - config=OmegaConf.to_container(cfg, resolve=True), + config=dataclasses.asdict(cfg), ) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) self._wandb = wandb @@ -238,3 +238,5 @@ class Logger: self._log_dir / "eval.csv", header=keys, index=None ) self._print(d, category) + timeit.print() + timeit.erase() diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 62b8230..91fcdce 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -9,30 +9,30 @@ def soft_ce(pred, target, cfg): return -(target * pred).sum(-1, keepdim=True) -@torch.jit.script + def log_std(x, low, dif): return low + 0.5 * dif * (torch.tanh(x) + 1) -@torch.jit.script + def _gaussian_residual(eps, log_std): return -0.5 * eps.pow(2) - log_std -@torch.jit.script def _gaussian_logprob(residual): - return residual - 0.5 * torch.log(2 * torch.pi) + log2pi = 1.8378770351409912 + return residual - 0.5 * log2pi def gaussian_logprob(eps, log_std, size=None): """Compute Gaussian log probability.""" residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) if size is None: - size = eps.size(-1) + size = eps.shape[-1] return _gaussian_logprob(residual) * size -@torch.jit.script + def _squash(pi): return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) @@ -45,7 +45,7 @@ def squash(mu, pi, log_pi): return mu, pi, log_pi -@torch.jit.script + def symlog(x): """ Symmetric logarithmic function. @@ -54,7 +54,7 @@ def symlog(x): return torch.sign(x) * torch.log(1 + torch.abs(x)) -@torch.jit.script + def symexp(x): """ Symmetric exponential function. @@ -70,26 +70,32 @@ def two_hot(x, cfg): elif cfg.num_bins == 1: return symlog(x) x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1) - bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() - bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) - soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) - soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset) - soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) + bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size) + bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx).unsqueeze(-1) + soft_two_hot = torch.zeros(x.shape[0], cfg.num_bins, device=x.device, dtype=x.dtype) + bin_idx = bin_idx.long() + soft_two_hot = soft_two_hot.scatter(1, bin_idx.unsqueeze(1), 1 - bin_offset) + soft_two_hot = soft_two_hot.scatter(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) return soft_two_hot -DREG_BINS = None - - def two_hot_inv(x, cfg): """Converts a batch of soft two-hot encoded vectors to scalars.""" - global DREG_BINS if cfg.num_bins == 0: return x elif cfg.num_bins == 1: return symexp(x) - if DREG_BINS is None: - DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device) + dreg_bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device, dtype=x.dtype) x = F.softmax(x, dim=-1) - x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True) + x = torch.sum(x * dreg_bins, dim=-1, keepdim=True) return symexp(x) + +def gumbel_softmax_sample(p, temperature=1.0, dim=0): + logits = p.log() + # Generate Gumbel noise + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + return y_soft.argmax(-1) diff --git a/tdmpc2/common/scale.py b/tdmpc2/common/scale.py index 63f0bb2..a9f4654 100644 --- a/tdmpc2/common/scale.py +++ b/tdmpc2/common/scale.py @@ -1,48 +1,53 @@ import torch +from torch.nn import Buffer - -class RunningScale: +class RunningScale(torch.nn.Module): """Running trimmed scale estimator.""" def __init__(self, cfg): + super().__init__() self.cfg = cfg - self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) - self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) + self._value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) + self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))) def state_dict(self): return dict(value=self._value, percentiles=self._percentiles) def load_state_dict(self, state_dict): - self._value.data.copy_(state_dict['value']) - self._percentiles.data.copy_(state_dict['percentiles']) + self._value.copy_(state_dict['value']) + self._percentiles.copy_(state_dict['percentiles']) @property def value(self): - return self._value.cpu().item() + return self._value + + def _positions(self, x_shape): + positions = self._percentiles * (x_shape-1) / 100 + floored = torch.floor(positions) + ceiled = floored + 1 + ceiled = torch.where(ceiled > x_shape - 1, x_shape - 1, ceiled) + weight_ceiled = positions-floored + weight_floored = 1.0 - weight_ceiled + return floored.long(), ceiled.long(), weight_floored.unsqueeze(1), weight_ceiled.unsqueeze(1) def _percentile(self, x): x_dtype, x_shape = x.dtype, x.shape - x = x.view(x.shape[0], -1) - in_sorted, _ = torch.sort(x, dim=0) - positions = self._percentiles * (x.shape[0]-1) / 100 - floored = torch.floor(positions) - ceiled = floored + 1 - ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1 - weight_ceiled = positions-floored - weight_floored = 1.0 - weight_ceiled - d0 = in_sorted[floored.long(), :] * weight_floored[:, None] - d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None] - return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype) + x = x.flatten(1, x.ndim-1) + in_sorted = torch.sort(x, dim=0).values + floored, ceiled, weight_floored, weight_ceiled = self._positions(x.shape[0]) + d0 = in_sorted[floored] * weight_floored + d1 = in_sorted[ceiled] * weight_ceiled + return (d0+d1).reshape(-1, *x_shape[1:]).to(x_dtype) def update(self, x): percentiles = self._percentile(x.detach()) value = torch.clamp(percentiles[1] - percentiles[0], min=1.) - self._value.data.lerp_(value, self.cfg.tau) + self._value.lerp_(value, self.cfg.tau) - def __call__(self, x, update=False): + def forward(self, x, update=False): if update: self.update(x) - return x * (1/self.value) + return x / self.value def __repr__(self): return f'RunningScale(S: {self.value})' diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index a780ad0..65572dd 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn from common import layers, math, init - +from tensordict import TensorDict +from tensordict.nn import TensorDictParams class WorldModel(nn.Module): """ @@ -18,7 +19,7 @@ class WorldModel(nn.Module): self.cfg = cfg if cfg.multitask: self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) - self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) + self.register_buffer("_action_masks", torch.zeros(len(cfg.tasks), cfg.action_dim)) for i in range(len(cfg.tasks)): self._action_masks[i, :cfg.action_dims[i]] = 1. self._encoder = layers.enc(cfg) @@ -27,26 +28,35 @@ class WorldModel(nn.Module): self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) - init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) - self._target_Qs = deepcopy(self._Qs).requires_grad_(False) - self.log_std_min = torch.tensor(cfg.log_std_min) - self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min + init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) + + self.register_buffer("log_std_min", torch.tensor(cfg.log_std_min)) + self.register_buffer("log_std_dif", torch.tensor(cfg.log_std_max) - self.log_std_min) + self.init() + + def init(self): + # Create params + self._detach_Qs_params = TensorDictParams(self._Qs.params.data, no_convert=True) + self._target_Qs_params = TensorDictParams(self._Qs.params.data.clone(), no_convert=True) + + # Create modules + with self._detach_Qs_params.data.to("meta").to_module(self._Qs.module): + self._detach_Qs = deepcopy(self._Qs) + self._target_Qs = deepcopy(self._Qs) + + # Assign params to modules + self._detach_Qs.params = self._detach_Qs_params + self._target_Qs.params = self._target_Qs_params @property def total_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) - + def to(self, *args, **kwargs): - """ - Overriding `to` method to also move additional tensors to device. - """ super().to(*args, **kwargs) - if self.cfg.multitask: - self._action_masks = self._action_masks.to(*args, **kwargs) - self.log_std_min = self.log_std_min.to(*args, **kwargs) - self.log_std_dif = self.log_std_dif.to(*args, **kwargs) + self.init() return self - + def train(self, mode=True): """ Overriding `train` method to keep target Q-networks in eval mode. @@ -55,26 +65,12 @@ class WorldModel(nn.Module): self._target_Qs.train(False) return self - def track_q_grad(self, mode=True): - """ - Enables/disables gradient tracking of Q-networks. - Avoids unnecessary computation during policy optimization. - This method also enables/disables gradients for task embeddings. - """ - for p in self._Qs.parameters(): - p.requires_grad_(mode) - if self.cfg.multitask: - for p in self._task_emb.parameters(): - p.requires_grad_(mode) - def soft_update_target_Q(self): """ Soft-update target Q-networks using Polyak averaging. """ - with torch.no_grad(): - for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()): - p_target.data.lerp_(p.data, self.cfg.tau) - + self._target_Qs_params.lerp_(self._detach_Qs_params, self.cfg.tau) + def task_emb(self, x, task): """ Continuous task embedding for multi-task experiments. @@ -109,7 +105,7 @@ class WorldModel(nn.Module): z = self.task_emb(z, task) z = torch.cat([z, a], dim=-1) return self._dynamics(z) - + def reward(self, z, a, task): """ Predicts instantaneous (single-step) reward. @@ -147,7 +143,7 @@ class WorldModel(nn.Module): return mu, pi, log_pi, log_std - def Q(self, z, a, task, return_type='min', target=False): + def Q(self, z, a, task, return_type='min', target=False, detach=False): """ Predict state-action value. `return_type` can be one of [`min`, `avg`, `all`]: @@ -160,13 +156,21 @@ class WorldModel(nn.Module): if self.cfg.multitask: z = self.task_emb(z, task) - + z = torch.cat([z, a], dim=-1) - out = (self._target_Qs if target else self._Qs)(z) + if target: + qnet = self._target_Qs + elif detach: + qnet = self._detach_Qs + else: + qnet = self._Qs + out = qnet(z) if return_type == 'all': return out - Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)] - Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg) - return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2 + qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] + Q = math.two_hot_inv(out[qidx], self.cfg) + if return_type == "min": + return Q.min(0).values + return Q.sum(0) / 2 diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..441e421 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -86,3 +86,7 @@ action_dims: ??? episode_lengths: ??? seed_steps: ??? bin_size: ??? + +# compile +compile: False +cudagraphs: False diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..6ce3d7f 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,13 +1,18 @@ -import numpy as np import torch import torch.nn.functional as F +import functools +from torchrl._utils import timeit from common import math from common.scale import RunningScale from common.world_model import WorldModel +from tensordict.nn import CudaGraphModule +from tensordict import TensorDict + +CG_WARMUP = 1000 -class TDMPC2: +class TDMPC2(torch.nn.Module): """ TD-MPC2 agent. Implements training + inference. Can be used for both single-task and multi-task experiments, @@ -15,23 +20,71 @@ class TDMPC2: """ def __init__(self, cfg): + super().__init__() self.cfg = cfg - self.device = torch.device('cuda') + + self.device = torch.device('cuda:0') + self.model = WorldModel(cfg).to(self.device) + capturable = True self.optim = torch.optim.Adam([ {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, {'params': self.model._Qs.parameters()}, - {'params': self.model._task_emb.parameters() if self.cfg.multitask else []} - ], lr=self.cfg.lr) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5) + {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] + } + ], lr=self.cfg.lr, capturable=capturable) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable) self.model.eval() self.scale = RunningScale(cfg) self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.discount = torch.tensor( - [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' + [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) + self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) + + if cfg.compile: + mode = None if cfg.cudagraphs else "reduce-overhead" + print('compiling - update') + self._update = torch.compile(self._update, mode=mode) + + if cfg.cudagraphs: + print('cudagraphs - update') + self._update = CudaGraphModule(self._update, warmup=CG_WARMUP) + + @property + def plan(self): + _plan_val = getattr(self, "_plan_val", None) + if _plan_val is not None: + return _plan_val + if self.cfg.cudagraphs: + print('cudagraphs - plan') + self._plan_dict = { + (True, True): functools.partial(self._plan, t0=True, eval_mode=True), + (False, True): functools.partial(self._plan, t0=False, eval_mode=True), + (True, False): functools.partial(self._plan, t0=True, eval_mode=False), + (False, False): functools.partial(self._plan, t0=False, eval_mode=False), + } + if self.cfg.compile: + print('compiling - plan') + mode = None + self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()} + + self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()} + def plan(obs, t0=False, eval_mode=False, task=None): + if task is not None: + kwargs = {"task": task} + else: + kwargs = {} + torch.compiler.cudagraph_mark_step_begin() + return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs) + elif self.cfg.compile: + plan = torch.compile(self._plan, mode="reduce-overhead") + else: + plan = self._plan + self._plan_val = plan + return self._plan_val def _get_discount(self, episode_length): """ @@ -51,7 +104,7 @@ class TDMPC2: def save(self, fp): """ Save state dict of the agent to filepath. - + Args: fp (str): Filepath to save state dict to. """ @@ -60,7 +113,7 @@ class TDMPC2: def load(self, fp): """ Load a saved state dict from filepath (or dictionary) into current agent. - + Args: fp (str or dict): Filepath or state dict to load. """ @@ -71,23 +124,23 @@ class TDMPC2: def act(self, obs, t0=False, eval_mode=False, task=None): """ Select an action by planning in the latent space of the world model. - + Args: obs (torch.Tensor): Observation from the environment. t0 (bool): Whether this is the first observation in the episode. eval_mode (bool): Whether to use the mean of the action distribution. task (int): Task index (only used for multi-task experiments). - + Returns: torch.Tensor: Action to take in the environment. """ obs = obs.to(self.device, non_blocking=True).unsqueeze(0) if task is not None: task = torch.tensor([task], device=self.device) - z = self.model.encode(obs, task) if self.cfg.mpc: - a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) + a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) else: + z = self.model.encode(obs, task) a = self.model.pi(z, task)[int(not eval_mode)][0] return a.cpu() @@ -98,15 +151,16 @@ class TDMPC2: for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G += discount * reward - discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount + G = G + discount * reward + discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount + discount = discount * discount_update return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') @torch.no_grad() - def plan(self, z, t0=False, eval_mode=False, task=None): + def _plan(self, obs, t0=False, eval_mode=False, task=None): """ Plan a sequence of actions using the learned world model. - + Args: z (torch.Tensor): Latent state from which to plan. t0 (bool): Whether this is the first observation in the episode. @@ -115,8 +169,9 @@ class TDMPC2: Returns: torch.Tensor: Action to take in the environment. - """ + """ # Sample policy trajectories + z = self.model.encode(obs, task) if self.cfg.num_pi_trajs > 0: pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) _z = z.repeat(self.cfg.num_pi_trajs, 1) @@ -128,52 +183,53 @@ class TDMPC2: # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) if not t0: mean[:-1] = self._prev_mean[1:] actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: actions[:, :self.cfg.num_pi_trajs] = pi_actions - + # Iterate MPPI for _ in range(self.cfg.iterations): # Sample actions - actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ - torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ - .clamp(-1, 1) + r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) + actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r + actions_sample = actions_sample.clamp(-1, 1) + actions[:, self.cfg.num_pi_trajs:] = actions_sample if self.cfg.multitask: actions = actions * self.model._action_masks[task] # Compute elite actions - value = self._estimate_value(z, actions, task).nan_to_num_(0) + value = self._estimate_value(z, actions, task).nan_to_num(0) elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] # Update parameters - max_value = elite_value.max(0)[0] + max_value = elite_value.max(0).values score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score /= score.sum(0) - mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) - std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ - .clamp_(self.cfg.min_std, self.cfg.max_std) + score = score / score.sum(0) + mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) + std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() + std = std.clamp(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * self.model._action_masks[task] std = std * self.model._action_masks[task] # Select action - score = score.squeeze(1).cpu().numpy() - actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] - self._prev_mean = mean + rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs + actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) a, std = actions[0], std[0] if not eval_mode: - a += std * torch.randn(self.cfg.action_dim, device=std.device) - return a.clamp_(-1, 1) - + a = a + std * torch.randn(self.cfg.action_dim, device=std.device) + self._prev_mean.copy_(mean) + return a.clamp(-1, 1) + def update_pi(self, zs, task): """ Update policy using a sequence of latent states. - + Args: zs (torch.Tensor): Sequence of latent states. task (torch.Tensor): Task index (only used for multi-task experiments). @@ -181,10 +237,8 @@ class TDMPC2: Returns: float: Loss of the policy update. """ - self.pi_optim.zero_grad(set_to_none=True) - self.model.track_q_grad(False) _, pis, log_pis, _ = self.model.pi(zs, task) - qs = self.model.Q(zs, pis, task, return_type='avg') + qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) self.scale.update(qs[0]) qs = self.scale(qs) @@ -192,22 +246,23 @@ class TDMPC2: rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() pi_loss.backward() - torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) + pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() - self.model.track_q_grad(True) + # For some reason, cudagraph prefers to see the zero grad after step + self.pi_optim.zero_grad(set_to_none=True) - return pi_loss.item() + return pi_loss.detach(), pi_grad_norm @torch.no_grad() def _td_target(self, next_z, reward, task): """ Compute the TD-target from a reward and the observation at the following time step. - + Args: next_z (torch.Tensor): Latent state at the following time step. reward (torch.Tensor): Reward at the current time step. task (torch.Tensor): Task index (only used for multi-task experiments). - + Returns: torch.Tensor: TD-target. """ @@ -218,22 +273,28 @@ class TDMPC2: def update(self, buffer): """ Main update function. Corresponds to one iteration of model learning. - + Args: buffer (common.buffer.Buffer): Replay buffer. - + Returns: dict: Dictionary of training statistics. """ - obs, action, reward, task = buffer.sample() - + with timeit("sample"): + obs, action, reward, task = buffer.sample() + kwargs = {} + if task is not None: + kwargs["task"] = task + torch.compiler.cudagraph_mark_step_begin() + return self._update(obs, action, reward, **kwargs) + + def _update(self, obs, action, reward, task=None): # Compute targets with torch.no_grad(): next_z = self.model.encode(obs[1:], task) td_targets = self._td_target(next_z, reward, task) # Prepare for update - self.optim.zero_grad(set_to_none=True) self.model.train() # Latent rollout @@ -241,25 +302,26 @@ class TDMPC2: z = self.model.encode(obs[0], task) zs[0] = z consistency_loss = 0 - for t in range(self.cfg.horizon): - z = self.model.next(z, action[t], task) - consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t + for t, (_action, _next_z) in enumerate(zip(action.unbind(0), next_z.unbind(0))): + z = self.model.next(z, _action, task) + consistency_loss = consistency_loss + F.mse_loss(z, _next_z) * self.cfg.rho**t zs[t+1] = z # Predictions _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - + # Compute losses reward_loss, value_loss = 0, 0 - for t in range(self.cfg.horizon): - reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t - for q in range(self.cfg.num_q): - value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t - consistency_loss *= (1/self.cfg.horizon) - reward_loss *= (1/self.cfg.horizon) - value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) + for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))): + reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t + for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): + value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t + + consistency_loss = consistency_loss / self.cfg.horizon + reward_loss = reward_loss / self.cfg.horizon + value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss + self.cfg.reward_coef * reward_loss + @@ -270,21 +332,23 @@ class TDMPC2: total_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm) self.optim.step() + self.optim.zero_grad(set_to_none=True) # Update policy - pi_loss = self.update_pi(zs.detach(), task) + pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) # Update target Q-functions self.model.soft_update_target_Q() # Return training statistics self.model.eval() - return { - "consistency_loss": float(consistency_loss.mean().item()), - "reward_loss": float(reward_loss.mean().item()), - "value_loss": float(value_loss.mean().item()), + return TensorDict({ + "consistency_loss": consistency_loss, + "reward_loss": reward_loss, + "value_loss": value_loss, "pi_loss": pi_loss, - "total_loss": float(total_loss.mean().item()), - "grad_norm": float(grad_norm), - "pi_scale": float(self.scale.value), - } + "total_loss": total_loss, + "grad_norm": grad_norm, + "pi_grad_norm": pi_grad_norm, + "pi_scale": self.scale.value, + }).detach().mean() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5953bb2..6afb648 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -1,6 +1,8 @@ import os os.environ['MUJOCO_GL'] = 'egl' os.environ['LAZY_LEGACY_OP'] = '0' +os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1" +os.environ['TORCH_LOGS'] = "+recompiles" import warnings warnings.filterwarnings('ignore') import torch @@ -16,10 +18,27 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger - +import dataclasses +from typing import Any +from omegaconf import OmegaConf torch.backends.cudnn.benchmark = True +def cfg_to_dataclass(cfg, frozen=False): + # Converts an OmegaConf config to a dataclass, which will not cause graph breaks + cfg_dict = OmegaConf.to_container(cfg) + fields = [] + for key, value in cfg_dict.items(): + fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) + + # Create the dataclass + dataclass_name = "Config" + dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) + def get(self, val, default=None): + return getattr(self, val, default) + dataclass.get = get + return dataclass() + @hydra.main(config_name='config', config_path='.') def train(cfg: dict): """ @@ -47,6 +66,9 @@ def train(cfg: dict): print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer + + cfg = cfg_to_dataclass(cfg) + trainer = trainer_cls( cfg=cfg, env=make_env(cfg), diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index a3326bc..f3072b5 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -3,7 +3,7 @@ from time import time import numpy as np import torch from tensordict.tensordict import TensorDict - +from torchrl._utils import timeit from trainer.base import Trainer @@ -57,61 +57,64 @@ class OnlineTrainer(Trainer): action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: reward = torch.tensor(float('nan')) - td = TensorDict(dict( + td = TensorDict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), - ), batch_size=(1,)) + batch_size=(1,)) return td def train(self): """Train a TD-MPC2 agent.""" - train_metrics, done, eval_next = {}, True, True + train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: + with timeit("global-step"): + # Evaluate agent periodically + if self._step > 0 and self._step % self.cfg.eval_freq == 0: + eval_next = True - # Evaluate agent periodically - if self._step % self.cfg.eval_freq == 0: - eval_next = True + # Reset environment + if done or (self._step == self.cfg.seed_steps + 1): + if eval_next: + eval_metrics = self.eval() + eval_metrics.update(self.common_metrics()) + self.logger.log(eval_metrics, 'eval') + eval_next = False - # Reset environment - if done: - if eval_next: - eval_metrics = self.eval() - eval_metrics.update(self.common_metrics()) - self.logger.log(eval_metrics, 'eval') - eval_next = False + if self._step > 0: + train_metrics.update( + episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), + episode_success=info['success'], + ) + train_metrics.update(self.common_metrics()) + train_metrics.update(timeit.todict()) + self.logger.log(train_metrics, 'train') + self._ep_idx = self.buffer.add(torch.cat(self._tds)) - if self._step > 0: - train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], - ) - train_metrics.update(self.common_metrics()) - self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + obs = self.env.reset() + self._tds = [self.to_td(obs)] - obs = self.env.reset() - self._tds = [self.to_td(obs)] + # Collect experience + with timeit("act"): + if self._step > self.cfg.seed_steps: + action = self.agent.act(obs, t0=len(self._tds)==1) + else: + action = self.env.rand_act() + obs, reward, done, info = self.env.step(action) + self._tds.append(self.to_td(obs, action, reward)) - # Collect experience - if self._step > self.cfg.seed_steps: - action = self.agent.act(obs, t0=len(self._tds)==1) - else: - action = self.env.rand_act() - obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + # Update agent + if self._step >= self.cfg.seed_steps: + if self._step == self.cfg.seed_steps: + num_updates = self.cfg.seed_steps + print('Pretraining agent on seed data...') + else: + num_updates = 1 + for _ in range(num_updates): + with timeit("update"): + _train_metrics = self.agent.update(self.buffer) + train_metrics.update(_train_metrics) - # Update agent - if self._step >= self.cfg.seed_steps: - if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps - print('Pretraining agent on seed data...') - else: - num_updates = 1 - for _ in range(num_updates): - _train_metrics = self.agent.update(self.buffer) - train_metrics.update(_train_metrics) + self._step += 1 - self._step += 1 - self.logger.finish(self.agent) From 8b731819a67105107d47882cb90a6654f933b427 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 25 Sep 2024 07:57:26 -0700 Subject: [PATCH 02/12] merged commits --- docker/environment.yaml | 62 ++++------ requirements.txt | 34 ++++++ tdmpc2/common/buffer.py | 34 +++--- tdmpc2/common/layers.py | 44 ++++--- tdmpc2/common/logger.py | 8 +- tdmpc2/common/math.py | 46 ++++--- tdmpc2/common/scale.py | 49 ++++---- tdmpc2/common/world_model.py | 80 ++++++------ tdmpc2/config.yaml | 4 + tdmpc2/tdmpc2.py | 202 ++++++++++++++++++++----------- tdmpc2/train.py | 25 +++- tdmpc2/trainer/online_trainer.py | 89 +++++++------- 12 files changed, 406 insertions(+), 271 deletions(-) create mode 100644 requirements.txt diff --git a/docker/environment.yaml b/docker/environment.yaml index 3081b8d..857c81a 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -1,56 +1,46 @@ -name: tdmpc2 +name: graph channels: - pytorch-nightly - nvidia - conda-forge - defaults dependencies: - - cudatoolkit=11.7 - - glew=2.1.0 - - glib=2.68.4 - - pip=21.0 - - python=3.9.0 - - pytorch>=2.2.2 - - torchvision>=0.16.2 + - glew=2.2.0 + - glib=2.78.4 + - pip=24.0 + - python=3.9 + - pytorch + - pytorch-cuda=12.4 + - torchvision - pip: - - absl-py==2.0.0 - - "cython<3" + - absl-py==2.1.0 - dm-control==1.0.8 + - glfw==2.7.0 - ffmpeg==1.4 - - glfw==2.6.4 + - imageio==2.34.1 + - imageio-ffmpeg==0.4.9 + - h5py==3.11.0 - hydra-core==1.3.2 - hydra-submitit-launcher==1.2.0 - - imageio==2.33.1 - - imageio-ffmpeg==0.4.9 - - kornia==0.7.1 + - submitit==1.5.1 + - omegaconf==2.3.0 - moviepy==1.0.3 - mujoco==2.3.1 - - mujoco-py==2.1.2.14 - - numpy==1.23.5 - - omegaconf==2.3.0 - - open3d==0.18.0 - - opencv-contrib-python==4.9.0.80 - - opencv-python==4.9.0.80 - - pandas==2.1.4 - - sapien==2.2.1 - - submitit==1.5.1 - - setuptools==65.5.0 - - patchelf==0.17.2.1 - - protobuf==4.25.2 - - pillow==10.2.0 - - pyquaternion==0.9.9 - - tensordict-nightly==2024.3.26 + - numpy==1.24.4 + - tensordict-nightly + - torchrl-nightly + - kornia==0.7.2 - termcolor==2.4.0 - - torchrl-nightly==2024.3.26 - - transforms3d==0.4.1 - - trimesh==4.0.9 - - tqdm==4.66.1 - - wandb==0.16.2 - - wheel==0.38.0 + - tqdm==4.66.4 + - pandas==2.0.3 + - wandb==0.17.4 + - matplotlib==3.7.5 + - seaborn==0.13.2 + - gpustat==1.1.1 #################### # Gym: # (unmaintained but required for maniskill2/meta-world/myosuite) - # - gym==0.21.0 + - gym==0.21.0 #################### # ManiSkill2: # (requires gym==0.21.0 which occasionally breaks) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f403b3e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +absl-py +cython +dm-control +ffmpeg +glfw +hydra-core +hydra-submitit-launcher +imageio +imageio-ffmpeg +kornia +moviepy +mujoco +mujoco-py +numpy<2 +omegaconf +open3d +opencv-contrib-python +opencv-python +pandas +sapien +submitit +setuptools +patchelf +protobuf +pillow +pyquaternion +tensordict-nightly +termcolor +torchrl-nightly +transforms3d +trimesh +tqdm +wandb +wheel diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index e82f2fe..3ff5b28 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -12,7 +12,7 @@ class Buffer(): def __init__(self, cfg): self.cfg = cfg - self._device = torch.device('cuda') + self._device = torch.device('cuda:0') self._capacity = min(cfg.buffer_size, cfg.steps) self._sampler = SliceSampler( num_slices=self.cfg.batch_size, @@ -28,7 +28,7 @@ class Buffer(): def capacity(self): """Return the capacity of the buffer.""" return self._capacity - + @property def num_eps(self): """Return the number of episodes in the buffer.""" @@ -41,8 +41,8 @@ class Buffer(): return ReplayBuffer( storage=storage, sampler=self._sampler, - pin_memory=True, - prefetch=1, + pin_memory=False, + prefetch=0, batch_size=self._batch_size, ) @@ -58,32 +58,30 @@ class Buffer(): total_bytes = bytes_per_step*self._capacity print(f'Storage required: {total_bytes/1e9:.2f} GB') # Heuristic: decide whether to use CUDA or CPU memory - storage_device = 'cuda' if 2.5*total_bytes < mem_free else 'cpu' + storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu' print(f'Using {storage_device.upper()} memory for storage.') + self._storage_device = torch.device(storage_device) return self._reserve_buffer( - LazyTensorStorage(self._capacity, device=torch.device(storage_device)) + LazyTensorStorage(self._capacity, device=self._storage_device) ) - def _to_device(self, *args, device=None): - if device is None: - device = self._device - return (arg.to(device, non_blocking=True) \ - if arg is not None else None for arg in args) - def _prepare_batch(self, td): """ Prepare a sampled batch for training (post-processing). Expects `td` to be a TensorDict with batch size TxB. """ - obs = td['obs'] - action = td['action'][1:] - reward = td['reward'][1:].unsqueeze(-1) - task = td['task'][0] if 'task' in td.keys() else None - return self._to_device(obs, action, reward, task) + td = td.select("obs", "action", "reward", "task", strict=False).to(self._device, non_blocking=True) + obs = td.get('obs').contiguous() + action = td.get('action')[1:].contiguous() + reward = td.get('reward')[1:].unsqueeze(-1).contiguous() + task = td.get('task', None) + if task is not None: + task = task[0].contiguous() + return obs, action, reward, task def add(self, td): """Add an episode to the buffer.""" - td['episode'] = torch.ones_like(td['reward'], dtype=torch.int64) * self._num_eps + td['episode'] = torch.full_like(td['reward'], self._num_eps, dtype=torch.int64) if self._num_eps == 0: self._buffer = self._init(td) self._buffer.extend(td) diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index 1e0adb3..5890d8d 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from functorch import combine_state_for_ensemble - +from tensordict import from_modules +from copy import deepcopy class Ensemble(nn.Module): """ @@ -11,14 +11,18 @@ class Ensemble(nn.Module): def __init__(self, modules, **kwargs): super().__init__() - modules = nn.ModuleList(modules) - fn, params, _ = combine_state_for_ensemble(modules) - self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness='different', **kwargs) - self.params = nn.ParameterList([nn.Parameter(p) for p in params]) + # combine_state_for_ensemble causes graph breaks + self.params = from_modules(*modules, as_module=True) + with self.params[0].data.to("meta").to_module(modules[0]): + self.module = deepcopy(modules[0]) self._repr = str(modules) + def _call(self, params, *args, **kwargs): + with params.to_module(self.module): + return self.module(*args, **kwargs) + def forward(self, *args, **kwargs): - return self.vmap([p for p in self.params], (), *args, **kwargs) + return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) def __repr__(self): return 'Vectorized ' + self._repr @@ -32,13 +36,13 @@ class ShiftAug(nn.Module): def __init__(self, pad=3): super().__init__() self.pad = pad + self.padding = tuple([self.pad] * 4) def forward(self, x): x = x.float() n, _, h, w = x.size() assert h == w - padding = tuple([self.pad] * 4) - x = F.pad(x, padding, 'replicate') + x = F.pad(x, self.padding, 'replicate') eps = 1.0 / (h + 2 * self.pad) arange = torch.linspace(-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype)[:h] arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) @@ -59,7 +63,7 @@ class PixelPreprocess(nn.Module): super().__init__() def forward(self, x): - return x.div_(255.).sub_(0.5) + return x.div(255.).sub(0.5) class SimNorm(nn.Module): @@ -67,17 +71,17 @@ class SimNorm(nn.Module): Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. """ - + def __init__(self, cfg): super().__init__() self.dim = cfg.simnorm_dim - + def forward(self, x): shp = x.shape x = x.view(*shp[:-1], -1, self.dim) x = F.softmax(x, dim=-1) return x.view(*shp) - + def __repr__(self): return f"SimNorm(dim={self.dim})" @@ -87,18 +91,20 @@ class NormedLinear(nn.Linear): Linear layer with LayerNorm, activation, and optionally dropout. """ - def __init__(self, *args, dropout=0., act=nn.Mish(inplace=True), **kwargs): + def __init__(self, *args, dropout=0., act=None, **kwargs): super().__init__(*args, **kwargs) self.ln = nn.LayerNorm(self.out_features) + if act is None: + act = nn.Mish(inplace=False) self.act = act - self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None + self.dropout = nn.Dropout(dropout, inplace=False) if dropout else None def forward(self, x): x = super().forward(x) if self.dropout: x = self.dropout(x) return self.act(self.ln(x)) - + def __repr__(self): repr_dropout = f", dropout={self.dropout.p}" if self.dropout else "" return f"NormedLinear(in_features={self.in_features}, "\ @@ -130,9 +136,9 @@ def conv(in_shape, num_channels, act=None): assert in_shape[-1] == 64 # assumes rgb observations to be 64x64 layers = [ ShiftAug(), PixelPreprocess(), - nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=True), - nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=True), + nn.Conv2d(in_shape[0], num_channels, 7, stride=2), nn.ReLU(inplace=False), + nn.Conv2d(num_channels, num_channels, 5, stride=2), nn.ReLU(inplace=False), + nn.Conv2d(num_channels, num_channels, 3, stride=2), nn.ReLU(inplace=False), nn.Conv2d(num_channels, num_channels, 3, stride=1), nn.Flatten()] if act: layers.append(act) diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index 4dce7ca..b3047f9 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -1,11 +1,11 @@ +import dataclasses import os import datetime import re import numpy as np import pandas as pd from termcolor import colored -from omegaconf import OmegaConf - +from torchrl._utils import timeit from common import TASK_SET @@ -133,7 +133,7 @@ class Logger: group=self._group, tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], dir=self._log_dir, - config=OmegaConf.to_container(cfg, resolve=True), + config=dataclasses.asdict(cfg), ) print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) self._wandb = wandb @@ -238,3 +238,5 @@ class Logger: self._log_dir / "eval.csv", header=keys, index=None ) self._print(d, category) + timeit.print() + timeit.erase() diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 62b8230..91fcdce 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -9,30 +9,30 @@ def soft_ce(pred, target, cfg): return -(target * pred).sum(-1, keepdim=True) -@torch.jit.script + def log_std(x, low, dif): return low + 0.5 * dif * (torch.tanh(x) + 1) -@torch.jit.script + def _gaussian_residual(eps, log_std): return -0.5 * eps.pow(2) - log_std -@torch.jit.script def _gaussian_logprob(residual): - return residual - 0.5 * torch.log(2 * torch.pi) + log2pi = 1.8378770351409912 + return residual - 0.5 * log2pi def gaussian_logprob(eps, log_std, size=None): """Compute Gaussian log probability.""" residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) if size is None: - size = eps.size(-1) + size = eps.shape[-1] return _gaussian_logprob(residual) * size -@torch.jit.script + def _squash(pi): return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) @@ -45,7 +45,7 @@ def squash(mu, pi, log_pi): return mu, pi, log_pi -@torch.jit.script + def symlog(x): """ Symmetric logarithmic function. @@ -54,7 +54,7 @@ def symlog(x): return torch.sign(x) * torch.log(1 + torch.abs(x)) -@torch.jit.script + def symexp(x): """ Symmetric exponential function. @@ -70,26 +70,32 @@ def two_hot(x, cfg): elif cfg.num_bins == 1: return symlog(x) x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax).squeeze(1) - bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() - bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) - soft_two_hot = torch.zeros(x.size(0), cfg.num_bins, device=x.device) - soft_two_hot.scatter_(1, bin_idx.unsqueeze(1), 1 - bin_offset) - soft_two_hot.scatter_(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) + bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size) + bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx).unsqueeze(-1) + soft_two_hot = torch.zeros(x.shape[0], cfg.num_bins, device=x.device, dtype=x.dtype) + bin_idx = bin_idx.long() + soft_two_hot = soft_two_hot.scatter(1, bin_idx.unsqueeze(1), 1 - bin_offset) + soft_two_hot = soft_two_hot.scatter(1, (bin_idx.unsqueeze(1) + 1) % cfg.num_bins, bin_offset) return soft_two_hot -DREG_BINS = None - - def two_hot_inv(x, cfg): """Converts a batch of soft two-hot encoded vectors to scalars.""" - global DREG_BINS if cfg.num_bins == 0: return x elif cfg.num_bins == 1: return symexp(x) - if DREG_BINS is None: - DREG_BINS = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device) + dreg_bins = torch.linspace(cfg.vmin, cfg.vmax, cfg.num_bins, device=x.device, dtype=x.dtype) x = F.softmax(x, dim=-1) - x = torch.sum(x * DREG_BINS, dim=-1, keepdim=True) + x = torch.sum(x * dreg_bins, dim=-1, keepdim=True) return symexp(x) + +def gumbel_softmax_sample(p, temperature=1.0, dim=0): + logits = p.log() + # Generate Gumbel noise + gumbels = ( + -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() + ) # ~Gumbel(0,1) + gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau) + y_soft = gumbels.softmax(dim) + return y_soft.argmax(-1) diff --git a/tdmpc2/common/scale.py b/tdmpc2/common/scale.py index 63f0bb2..8fd1740 100644 --- a/tdmpc2/common/scale.py +++ b/tdmpc2/common/scale.py @@ -1,48 +1,49 @@ import torch +from torch.nn import Buffer - -class RunningScale: +class RunningScale(torch.nn.Module): """Running trimmed scale estimator.""" def __init__(self, cfg): + super().__init__() self.cfg = cfg - self._value = torch.ones(1, dtype=torch.float32, device=torch.device('cuda')) - self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')) + self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) + self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))) def state_dict(self): - return dict(value=self._value, percentiles=self._percentiles) + return dict(value=self.value, percentiles=self._percentiles) def load_state_dict(self, state_dict): - self._value.data.copy_(state_dict['value']) - self._percentiles.data.copy_(state_dict['percentiles']) + self.value.copy_(state_dict['value']) + self._percentiles.copy_(state_dict['percentiles']) - @property - def value(self): - return self._value.cpu().item() + def _positions(self, x_shape): + positions = self._percentiles * (x_shape-1) / 100 + floored = torch.floor(positions) + ceiled = floored + 1 + ceiled = torch.where(ceiled > x_shape - 1, x_shape - 1, ceiled) + weight_ceiled = positions-floored + weight_floored = 1.0 - weight_ceiled + return floored.long(), ceiled.long(), weight_floored.unsqueeze(1), weight_ceiled.unsqueeze(1) def _percentile(self, x): x_dtype, x_shape = x.dtype, x.shape - x = x.view(x.shape[0], -1) - in_sorted, _ = torch.sort(x, dim=0) - positions = self._percentiles * (x.shape[0]-1) / 100 - floored = torch.floor(positions) - ceiled = floored + 1 - ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1 - weight_ceiled = positions-floored - weight_floored = 1.0 - weight_ceiled - d0 = in_sorted[floored.long(), :] * weight_floored[:, None] - d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None] - return (d0+d1).view(-1, *x_shape[1:]).type(x_dtype) + x = x.flatten(1, x.ndim-1) + in_sorted = torch.sort(x, dim=0).values + floored, ceiled, weight_floored, weight_ceiled = self._positions(x.shape[0]) + d0 = in_sorted[floored] * weight_floored + d1 = in_sorted[ceiled] * weight_ceiled + return (d0+d1).reshape(-1, *x_shape[1:]).to(x_dtype) def update(self, x): percentiles = self._percentile(x.detach()) value = torch.clamp(percentiles[1] - percentiles[0], min=1.) - self._value.data.lerp_(value, self.cfg.tau) + self.value.data.lerp_(value, self.cfg.tau) - def __call__(self, x, update=False): + def forward(self, x, update=False): if update: self.update(x) - return x * (1/self.value) + return x / self.value def __repr__(self): return f'RunningScale(S: {self.value})' diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index a780ad0..65572dd 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -5,7 +5,8 @@ import torch import torch.nn as nn from common import layers, math, init - +from tensordict import TensorDict +from tensordict.nn import TensorDictParams class WorldModel(nn.Module): """ @@ -18,7 +19,7 @@ class WorldModel(nn.Module): self.cfg = cfg if cfg.multitask: self._task_emb = nn.Embedding(len(cfg.tasks), cfg.task_dim, max_norm=1) - self._action_masks = torch.zeros(len(cfg.tasks), cfg.action_dim) + self.register_buffer("_action_masks", torch.zeros(len(cfg.tasks), cfg.action_dim)) for i in range(len(cfg.tasks)): self._action_masks[i, :cfg.action_dims[i]] = 1. self._encoder = layers.enc(cfg) @@ -27,26 +28,35 @@ class WorldModel(nn.Module): self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) - init.zero_([self._reward[-1].weight, self._Qs.params[-2]]) - self._target_Qs = deepcopy(self._Qs).requires_grad_(False) - self.log_std_min = torch.tensor(cfg.log_std_min) - self.log_std_dif = torch.tensor(cfg.log_std_max) - self.log_std_min + init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) + + self.register_buffer("log_std_min", torch.tensor(cfg.log_std_min)) + self.register_buffer("log_std_dif", torch.tensor(cfg.log_std_max) - self.log_std_min) + self.init() + + def init(self): + # Create params + self._detach_Qs_params = TensorDictParams(self._Qs.params.data, no_convert=True) + self._target_Qs_params = TensorDictParams(self._Qs.params.data.clone(), no_convert=True) + + # Create modules + with self._detach_Qs_params.data.to("meta").to_module(self._Qs.module): + self._detach_Qs = deepcopy(self._Qs) + self._target_Qs = deepcopy(self._Qs) + + # Assign params to modules + self._detach_Qs.params = self._detach_Qs_params + self._target_Qs.params = self._target_Qs_params @property def total_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) - + def to(self, *args, **kwargs): - """ - Overriding `to` method to also move additional tensors to device. - """ super().to(*args, **kwargs) - if self.cfg.multitask: - self._action_masks = self._action_masks.to(*args, **kwargs) - self.log_std_min = self.log_std_min.to(*args, **kwargs) - self.log_std_dif = self.log_std_dif.to(*args, **kwargs) + self.init() return self - + def train(self, mode=True): """ Overriding `train` method to keep target Q-networks in eval mode. @@ -55,26 +65,12 @@ class WorldModel(nn.Module): self._target_Qs.train(False) return self - def track_q_grad(self, mode=True): - """ - Enables/disables gradient tracking of Q-networks. - Avoids unnecessary computation during policy optimization. - This method also enables/disables gradients for task embeddings. - """ - for p in self._Qs.parameters(): - p.requires_grad_(mode) - if self.cfg.multitask: - for p in self._task_emb.parameters(): - p.requires_grad_(mode) - def soft_update_target_Q(self): """ Soft-update target Q-networks using Polyak averaging. """ - with torch.no_grad(): - for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters()): - p_target.data.lerp_(p.data, self.cfg.tau) - + self._target_Qs_params.lerp_(self._detach_Qs_params, self.cfg.tau) + def task_emb(self, x, task): """ Continuous task embedding for multi-task experiments. @@ -109,7 +105,7 @@ class WorldModel(nn.Module): z = self.task_emb(z, task) z = torch.cat([z, a], dim=-1) return self._dynamics(z) - + def reward(self, z, a, task): """ Predicts instantaneous (single-step) reward. @@ -147,7 +143,7 @@ class WorldModel(nn.Module): return mu, pi, log_pi, log_std - def Q(self, z, a, task, return_type='min', target=False): + def Q(self, z, a, task, return_type='min', target=False, detach=False): """ Predict state-action value. `return_type` can be one of [`min`, `avg`, `all`]: @@ -160,13 +156,21 @@ class WorldModel(nn.Module): if self.cfg.multitask: z = self.task_emb(z, task) - + z = torch.cat([z, a], dim=-1) - out = (self._target_Qs if target else self._Qs)(z) + if target: + qnet = self._target_Qs + elif detach: + qnet = self._detach_Qs + else: + qnet = self._Qs + out = qnet(z) if return_type == 'all': return out - Q1, Q2 = out[np.random.choice(self.cfg.num_q, 2, replace=False)] - Q1, Q2 = math.two_hot_inv(Q1, self.cfg), math.two_hot_inv(Q2, self.cfg) - return torch.min(Q1, Q2) if return_type == 'min' else (Q1 + Q2) / 2 + qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] + Q = math.two_hot_inv(out[qidx], self.cfg) + if return_type == "min": + return Q.min(0).values + return Q.sum(0) / 2 diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..441e421 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -86,3 +86,7 @@ action_dims: ??? episode_lengths: ??? seed_steps: ??? bin_size: ??? + +# compile +compile: False +cudagraphs: False diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..6ce3d7f 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,13 +1,18 @@ -import numpy as np import torch import torch.nn.functional as F +import functools +from torchrl._utils import timeit from common import math from common.scale import RunningScale from common.world_model import WorldModel +from tensordict.nn import CudaGraphModule +from tensordict import TensorDict + +CG_WARMUP = 1000 -class TDMPC2: +class TDMPC2(torch.nn.Module): """ TD-MPC2 agent. Implements training + inference. Can be used for both single-task and multi-task experiments, @@ -15,23 +20,71 @@ class TDMPC2: """ def __init__(self, cfg): + super().__init__() self.cfg = cfg - self.device = torch.device('cuda') + + self.device = torch.device('cuda:0') + self.model = WorldModel(cfg).to(self.device) + capturable = True self.optim = torch.optim.Adam([ {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, {'params': self.model._Qs.parameters()}, - {'params': self.model._task_emb.parameters() if self.cfg.multitask else []} - ], lr=self.cfg.lr) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5) + {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] + } + ], lr=self.cfg.lr, capturable=capturable) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable) self.model.eval() self.scale = RunningScale(cfg) self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.discount = torch.tensor( - [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda' + [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) + self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) + + if cfg.compile: + mode = None if cfg.cudagraphs else "reduce-overhead" + print('compiling - update') + self._update = torch.compile(self._update, mode=mode) + + if cfg.cudagraphs: + print('cudagraphs - update') + self._update = CudaGraphModule(self._update, warmup=CG_WARMUP) + + @property + def plan(self): + _plan_val = getattr(self, "_plan_val", None) + if _plan_val is not None: + return _plan_val + if self.cfg.cudagraphs: + print('cudagraphs - plan') + self._plan_dict = { + (True, True): functools.partial(self._plan, t0=True, eval_mode=True), + (False, True): functools.partial(self._plan, t0=False, eval_mode=True), + (True, False): functools.partial(self._plan, t0=True, eval_mode=False), + (False, False): functools.partial(self._plan, t0=False, eval_mode=False), + } + if self.cfg.compile: + print('compiling - plan') + mode = None + self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()} + + self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()} + def plan(obs, t0=False, eval_mode=False, task=None): + if task is not None: + kwargs = {"task": task} + else: + kwargs = {} + torch.compiler.cudagraph_mark_step_begin() + return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs) + elif self.cfg.compile: + plan = torch.compile(self._plan, mode="reduce-overhead") + else: + plan = self._plan + self._plan_val = plan + return self._plan_val def _get_discount(self, episode_length): """ @@ -51,7 +104,7 @@ class TDMPC2: def save(self, fp): """ Save state dict of the agent to filepath. - + Args: fp (str): Filepath to save state dict to. """ @@ -60,7 +113,7 @@ class TDMPC2: def load(self, fp): """ Load a saved state dict from filepath (or dictionary) into current agent. - + Args: fp (str or dict): Filepath or state dict to load. """ @@ -71,23 +124,23 @@ class TDMPC2: def act(self, obs, t0=False, eval_mode=False, task=None): """ Select an action by planning in the latent space of the world model. - + Args: obs (torch.Tensor): Observation from the environment. t0 (bool): Whether this is the first observation in the episode. eval_mode (bool): Whether to use the mean of the action distribution. task (int): Task index (only used for multi-task experiments). - + Returns: torch.Tensor: Action to take in the environment. """ obs = obs.to(self.device, non_blocking=True).unsqueeze(0) if task is not None: task = torch.tensor([task], device=self.device) - z = self.model.encode(obs, task) if self.cfg.mpc: - a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) + a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) else: + z = self.model.encode(obs, task) a = self.model.pi(z, task)[int(not eval_mode)][0] return a.cpu() @@ -98,15 +151,16 @@ class TDMPC2: for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G += discount * reward - discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount + G = G + discount * reward + discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount + discount = discount * discount_update return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') @torch.no_grad() - def plan(self, z, t0=False, eval_mode=False, task=None): + def _plan(self, obs, t0=False, eval_mode=False, task=None): """ Plan a sequence of actions using the learned world model. - + Args: z (torch.Tensor): Latent state from which to plan. t0 (bool): Whether this is the first observation in the episode. @@ -115,8 +169,9 @@ class TDMPC2: Returns: torch.Tensor: Action to take in the environment. - """ + """ # Sample policy trajectories + z = self.model.encode(obs, task) if self.cfg.num_pi_trajs > 0: pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) _z = z.repeat(self.cfg.num_pi_trajs, 1) @@ -128,52 +183,53 @@ class TDMPC2: # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = self.cfg.max_std*torch.ones(self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) if not t0: mean[:-1] = self._prev_mean[1:] actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: actions[:, :self.cfg.num_pi_trajs] = pi_actions - + # Iterate MPPI for _ in range(self.cfg.iterations): # Sample actions - actions[:, self.cfg.num_pi_trajs:] = (mean.unsqueeze(1) + std.unsqueeze(1) * \ - torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)) \ - .clamp(-1, 1) + r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) + actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r + actions_sample = actions_sample.clamp(-1, 1) + actions[:, self.cfg.num_pi_trajs:] = actions_sample if self.cfg.multitask: actions = actions * self.model._action_masks[task] # Compute elite actions - value = self._estimate_value(z, actions, task).nan_to_num_(0) + value = self._estimate_value(z, actions, task).nan_to_num(0) elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] # Update parameters - max_value = elite_value.max(0)[0] + max_value = elite_value.max(0).values score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score /= score.sum(0) - mean = torch.sum(score.unsqueeze(0) * elite_actions, dim=1) / (score.sum(0) + 1e-9) - std = torch.sqrt(torch.sum(score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2, dim=1) / (score.sum(0) + 1e-9)) \ - .clamp_(self.cfg.min_std, self.cfg.max_std) + score = score / score.sum(0) + mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) + std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() + std = std.clamp(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * self.model._action_masks[task] std = std * self.model._action_masks[task] # Select action - score = score.squeeze(1).cpu().numpy() - actions = elite_actions[:, np.random.choice(np.arange(score.shape[0]), p=score)] - self._prev_mean = mean + rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs + actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) a, std = actions[0], std[0] if not eval_mode: - a += std * torch.randn(self.cfg.action_dim, device=std.device) - return a.clamp_(-1, 1) - + a = a + std * torch.randn(self.cfg.action_dim, device=std.device) + self._prev_mean.copy_(mean) + return a.clamp(-1, 1) + def update_pi(self, zs, task): """ Update policy using a sequence of latent states. - + Args: zs (torch.Tensor): Sequence of latent states. task (torch.Tensor): Task index (only used for multi-task experiments). @@ -181,10 +237,8 @@ class TDMPC2: Returns: float: Loss of the policy update. """ - self.pi_optim.zero_grad(set_to_none=True) - self.model.track_q_grad(False) _, pis, log_pis, _ = self.model.pi(zs, task) - qs = self.model.Q(zs, pis, task, return_type='avg') + qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) self.scale.update(qs[0]) qs = self.scale(qs) @@ -192,22 +246,23 @@ class TDMPC2: rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() pi_loss.backward() - torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) + pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() - self.model.track_q_grad(True) + # For some reason, cudagraph prefers to see the zero grad after step + self.pi_optim.zero_grad(set_to_none=True) - return pi_loss.item() + return pi_loss.detach(), pi_grad_norm @torch.no_grad() def _td_target(self, next_z, reward, task): """ Compute the TD-target from a reward and the observation at the following time step. - + Args: next_z (torch.Tensor): Latent state at the following time step. reward (torch.Tensor): Reward at the current time step. task (torch.Tensor): Task index (only used for multi-task experiments). - + Returns: torch.Tensor: TD-target. """ @@ -218,22 +273,28 @@ class TDMPC2: def update(self, buffer): """ Main update function. Corresponds to one iteration of model learning. - + Args: buffer (common.buffer.Buffer): Replay buffer. - + Returns: dict: Dictionary of training statistics. """ - obs, action, reward, task = buffer.sample() - + with timeit("sample"): + obs, action, reward, task = buffer.sample() + kwargs = {} + if task is not None: + kwargs["task"] = task + torch.compiler.cudagraph_mark_step_begin() + return self._update(obs, action, reward, **kwargs) + + def _update(self, obs, action, reward, task=None): # Compute targets with torch.no_grad(): next_z = self.model.encode(obs[1:], task) td_targets = self._td_target(next_z, reward, task) # Prepare for update - self.optim.zero_grad(set_to_none=True) self.model.train() # Latent rollout @@ -241,25 +302,26 @@ class TDMPC2: z = self.model.encode(obs[0], task) zs[0] = z consistency_loss = 0 - for t in range(self.cfg.horizon): - z = self.model.next(z, action[t], task) - consistency_loss += F.mse_loss(z, next_z[t]) * self.cfg.rho**t + for t, (_action, _next_z) in enumerate(zip(action.unbind(0), next_z.unbind(0))): + z = self.model.next(z, _action, task) + consistency_loss = consistency_loss + F.mse_loss(z, _next_z) * self.cfg.rho**t zs[t+1] = z # Predictions _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - + # Compute losses reward_loss, value_loss = 0, 0 - for t in range(self.cfg.horizon): - reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t - for q in range(self.cfg.num_q): - value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t - consistency_loss *= (1/self.cfg.horizon) - reward_loss *= (1/self.cfg.horizon) - value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) + for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))): + reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t + for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): + value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t + + consistency_loss = consistency_loss / self.cfg.horizon + reward_loss = reward_loss / self.cfg.horizon + value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss + self.cfg.reward_coef * reward_loss + @@ -270,21 +332,23 @@ class TDMPC2: total_loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm) self.optim.step() + self.optim.zero_grad(set_to_none=True) # Update policy - pi_loss = self.update_pi(zs.detach(), task) + pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) # Update target Q-functions self.model.soft_update_target_Q() # Return training statistics self.model.eval() - return { - "consistency_loss": float(consistency_loss.mean().item()), - "reward_loss": float(reward_loss.mean().item()), - "value_loss": float(value_loss.mean().item()), + return TensorDict({ + "consistency_loss": consistency_loss, + "reward_loss": reward_loss, + "value_loss": value_loss, "pi_loss": pi_loss, - "total_loss": float(total_loss.mean().item()), - "grad_norm": float(grad_norm), - "pi_scale": float(self.scale.value), - } + "total_loss": total_loss, + "grad_norm": grad_norm, + "pi_grad_norm": pi_grad_norm, + "pi_scale": self.scale.value, + }).detach().mean() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5953bb2..eb1249a 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -1,6 +1,8 @@ import os os.environ['MUJOCO_GL'] = 'egl' os.environ['LAZY_LEGACY_OP'] = '0' +os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1" +os.environ['TORCH_LOGS'] = "+recompiles" import warnings warnings.filterwarnings('ignore') import torch @@ -16,9 +18,27 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger - +import dataclasses +from typing import Any +from omegaconf import OmegaConf torch.backends.cudnn.benchmark = True +torch.set_float32_matmul_precision('high') + +def cfg_to_dataclass(cfg, frozen=False): + # Converts an OmegaConf config to a dataclass, which will not cause graph breaks + cfg_dict = OmegaConf.to_container(cfg) + fields = [] + for key, value in cfg_dict.items(): + fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) + + # Create the dataclass + dataclass_name = "Config" + dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) + def get(self, val, default=None): + return getattr(self, val, default) + dataclass.get = get + return dataclass() @hydra.main(config_name='config', config_path='.') def train(cfg: dict): @@ -47,6 +67,9 @@ def train(cfg: dict): print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer + + cfg = cfg_to_dataclass(cfg) + trainer = trainer_cls( cfg=cfg, env=make_env(cfg), diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index a3326bc..f3072b5 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -3,7 +3,7 @@ from time import time import numpy as np import torch from tensordict.tensordict import TensorDict - +from torchrl._utils import timeit from trainer.base import Trainer @@ -57,61 +57,64 @@ class OnlineTrainer(Trainer): action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: reward = torch.tensor(float('nan')) - td = TensorDict(dict( + td = TensorDict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), - ), batch_size=(1,)) + batch_size=(1,)) return td def train(self): """Train a TD-MPC2 agent.""" - train_metrics, done, eval_next = {}, True, True + train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: + with timeit("global-step"): + # Evaluate agent periodically + if self._step > 0 and self._step % self.cfg.eval_freq == 0: + eval_next = True - # Evaluate agent periodically - if self._step % self.cfg.eval_freq == 0: - eval_next = True + # Reset environment + if done or (self._step == self.cfg.seed_steps + 1): + if eval_next: + eval_metrics = self.eval() + eval_metrics.update(self.common_metrics()) + self.logger.log(eval_metrics, 'eval') + eval_next = False - # Reset environment - if done: - if eval_next: - eval_metrics = self.eval() - eval_metrics.update(self.common_metrics()) - self.logger.log(eval_metrics, 'eval') - eval_next = False + if self._step > 0: + train_metrics.update( + episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), + episode_success=info['success'], + ) + train_metrics.update(self.common_metrics()) + train_metrics.update(timeit.todict()) + self.logger.log(train_metrics, 'train') + self._ep_idx = self.buffer.add(torch.cat(self._tds)) - if self._step > 0: - train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], - ) - train_metrics.update(self.common_metrics()) - self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + obs = self.env.reset() + self._tds = [self.to_td(obs)] - obs = self.env.reset() - self._tds = [self.to_td(obs)] + # Collect experience + with timeit("act"): + if self._step > self.cfg.seed_steps: + action = self.agent.act(obs, t0=len(self._tds)==1) + else: + action = self.env.rand_act() + obs, reward, done, info = self.env.step(action) + self._tds.append(self.to_td(obs, action, reward)) - # Collect experience - if self._step > self.cfg.seed_steps: - action = self.agent.act(obs, t0=len(self._tds)==1) - else: - action = self.env.rand_act() - obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + # Update agent + if self._step >= self.cfg.seed_steps: + if self._step == self.cfg.seed_steps: + num_updates = self.cfg.seed_steps + print('Pretraining agent on seed data...') + else: + num_updates = 1 + for _ in range(num_updates): + with timeit("update"): + _train_metrics = self.agent.update(self.buffer) + train_metrics.update(_train_metrics) - # Update agent - if self._step >= self.cfg.seed_steps: - if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps - print('Pretraining agent on seed data...') - else: - num_updates = 1 - for _ in range(num_updates): - _train_metrics = self.agent.update(self.buffer) - train_metrics.update(_train_metrics) + self._step += 1 - self._step += 1 - self.logger.finish(self.agent) From 970792e2b6477956582705e6b43f994bd4daafb7 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Fri, 18 Oct 2024 15:31:25 -0700 Subject: [PATCH 03/12] clean up prints --- requirements.txt | 34 -------------- tdmpc2/common/logger.py | 3 -- tdmpc2/common/world_model.py | 10 +++- tdmpc2/tdmpc2.py | 4 +- tdmpc2/trainer/base.py | 1 - tdmpc2/trainer/online_trainer.py | 81 +++++++++++++++----------------- 6 files changed, 47 insertions(+), 86 deletions(-) delete mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index f403b3e..0000000 --- a/requirements.txt +++ /dev/null @@ -1,34 +0,0 @@ -absl-py -cython -dm-control -ffmpeg -glfw -hydra-core -hydra-submitit-launcher -imageio -imageio-ffmpeg -kornia -moviepy -mujoco -mujoco-py -numpy<2 -omegaconf -open3d -opencv-contrib-python -opencv-python -pandas -sapien -submitit -setuptools -patchelf -protobuf -pillow -pyquaternion -tensordict-nightly -termcolor -torchrl-nightly -transforms3d -trimesh -tqdm -wandb -wheel diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index b3047f9..bdbb998 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -5,7 +5,6 @@ import re import numpy as np import pandas as pd from termcolor import colored -from torchrl._utils import timeit from common import TASK_SET @@ -238,5 +237,3 @@ class Logger: self._log_dir / "eval.csv", header=keys, index=None ) self._print(d, category) - timeit.print() - timeit.erase() diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 65572dd..eb9633d 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -1,11 +1,9 @@ from copy import deepcopy -import numpy as np import torch import torch.nn as nn from common import layers, math, init -from tensordict import TensorDict from tensordict.nn import TensorDictParams class WorldModel(nn.Module): @@ -48,6 +46,14 @@ class WorldModel(nn.Module): self._detach_Qs.params = self._detach_Qs_params self._target_Qs.params = self._target_Qs_params + def __repr__(self): + repr = 'TD-MPC2 World Model\n' + modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions'] + for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]): + repr += f"{modules[i]}: {m}\n" + repr += "Learnable parameters: {:,}".format(self.total_params) + return repr + @property def total_params(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 6ce3d7f..a4da1db 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,7 +1,6 @@ import torch import torch.nn.functional as F import functools -from torchrl._utils import timeit from common import math from common.scale import RunningScale @@ -280,8 +279,7 @@ class TDMPC2(torch.nn.Module): Returns: dict: Dictionary of training statistics. """ - with timeit("sample"): - obs, action, reward, task = buffer.sample() + obs, action, reward, task = buffer.sample() kwargs = {} if task is not None: kwargs["task"] = task diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index 27a328d..6d14783 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -8,7 +8,6 @@ class Trainer: self.buffer = buffer self.logger = logger print('Architecture:', self.agent.model) - print("Learnable parameters: {:,}".format(self.agent.model.total_params)) def eval(self): """Evaluate a TD-MPC2 agent.""" diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index f3072b5..103d129 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -3,7 +3,6 @@ from time import time import numpy as np import torch from tensordict.tensordict import TensorDict -from torchrl._utils import timeit from trainer.base import Trainer @@ -68,53 +67,49 @@ class OnlineTrainer(Trainer): """Train a TD-MPC2 agent.""" train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: - with timeit("global-step"): - # Evaluate agent periodically - if self._step > 0 and self._step % self.cfg.eval_freq == 0: - eval_next = True + # Evaluate agent periodically + if self._step > 0 and self._step % self.cfg.eval_freq == 0: + eval_next = True - # Reset environment - if done or (self._step == self.cfg.seed_steps + 1): - if eval_next: - eval_metrics = self.eval() - eval_metrics.update(self.common_metrics()) - self.logger.log(eval_metrics, 'eval') - eval_next = False + # Reset environment + if done or (self._step == self.cfg.seed_steps + 1): + if eval_next: + eval_metrics = self.eval() + eval_metrics.update(self.common_metrics()) + self.logger.log(eval_metrics, 'eval') + eval_next = False - if self._step > 0: - train_metrics.update( - episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), - episode_success=info['success'], - ) - train_metrics.update(self.common_metrics()) - train_metrics.update(timeit.todict()) - self.logger.log(train_metrics, 'train') - self._ep_idx = self.buffer.add(torch.cat(self._tds)) + if self._step > 0: + train_metrics.update( + episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), + episode_success=info['success'], + ) + train_metrics.update(self.common_metrics()) + self.logger.log(train_metrics, 'train') + self._ep_idx = self.buffer.add(torch.cat(self._tds)) - obs = self.env.reset() - self._tds = [self.to_td(obs)] + obs = self.env.reset() + self._tds = [self.to_td(obs)] - # Collect experience - with timeit("act"): - if self._step > self.cfg.seed_steps: - action = self.agent.act(obs, t0=len(self._tds)==1) - else: - action = self.env.rand_act() - obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + # Collect experience + if self._step > self.cfg.seed_steps: + action = self.agent.act(obs, t0=len(self._tds)==1) + else: + action = self.env.rand_act() + obs, reward, done, info = self.env.step(action) + self._tds.append(self.to_td(obs, action, reward)) - # Update agent - if self._step >= self.cfg.seed_steps: - if self._step == self.cfg.seed_steps: - num_updates = self.cfg.seed_steps - print('Pretraining agent on seed data...') - else: - num_updates = 1 - for _ in range(num_updates): - with timeit("update"): - _train_metrics = self.agent.update(self.buffer) - train_metrics.update(_train_metrics) + # Update agent + if self._step >= self.cfg.seed_steps: + if self._step == self.cfg.seed_steps: + num_updates = self.cfg.seed_steps + print('Pretraining agent on seed data...') + else: + num_updates = 1 + for _ in range(num_updates): + _train_metrics = self.agent.update(self.buffer) + train_metrics.update(_train_metrics) - self._step += 1 + self._step += 1 self.logger.finish(self.agent) From 836547d76fe0d35a625fa8b5dc59660137a92e77 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 21 Oct 2024 14:49:21 -0700 Subject: [PATCH 04/12] fix eval index + clean up --- README.md | 9 ++++ tdmpc2/common/logger.py | 4 +- tdmpc2/common/math.py | 6 +-- tdmpc2/config.yaml | 5 +- tdmpc2/tdmpc2.py | 79 +++++++++----------------------- tdmpc2/trainer/online_trainer.py | 4 +- 6 files changed, 39 insertions(+), 68 deletions(-) diff --git a/README.md b/README.md index 687b73d..0e74c2f 100755 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ Official implementation of ---- +**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.** + +Expect **3-8x** faster wall-time (depending on hardware and task) compared to `main` branch. A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository. + +Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to this branch! + +---- + + ## Overview TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*). diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index bdbb998..8ea2c2e 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -2,9 +2,11 @@ import dataclasses import os import datetime import re + import numpy as np import pandas as pd from termcolor import colored + from common import TASK_SET @@ -115,7 +117,7 @@ class Logger: print_run(cfg) self.project = cfg.get("wandb_project", "none") self.entity = cfg.get("wandb_entity", "none") - if cfg.disable_wandb or self.project == "none" or self.entity == "none": + if not cfg.enable_wandb or self.project == "none" or self.entity == "none": print(colored("Wandb disabled.", "blue", attrs=["bold"])) cfg.save_agent = False cfg.save_video = False diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 91fcdce..5ac92ad 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -9,12 +9,10 @@ def soft_ce(pred, target, cfg): return -(target * pred).sum(-1, keepdim=True) - def log_std(x, low, dif): return low + 0.5 * dif * (torch.tanh(x) + 1) - def _gaussian_residual(eps, log_std): return -0.5 * eps.pow(2) - log_std @@ -32,7 +30,6 @@ def gaussian_logprob(eps, log_std, size=None): return _gaussian_logprob(residual) * size - def _squash(pi): return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) @@ -45,7 +42,6 @@ def squash(mu, pi, log_pi): return mu, pi, log_pi - def symlog(x): """ Symmetric logarithmic function. @@ -54,7 +50,6 @@ def symlog(x): return torch.sign(x) * torch.log(1 + torch.abs(x)) - def symexp(x): """ Symmetric exponential function. @@ -90,6 +85,7 @@ def two_hot_inv(x, cfg): x = torch.sum(x * dreg_bins, dim=-1, keepdim=True) return symexp(x) + def gumbel_softmax_sample(p, temperature=1.0, dim=0): logits = p.log() # Generate Gumbel noise diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 441e421..597c829 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -65,7 +65,7 @@ simnorm_dim: 8 wandb_project: ??? wandb_entity: ??? wandb_silent: false -disable_wandb: true +enable_wandb: true save_csv: true # misc @@ -87,6 +87,5 @@ episode_lengths: ??? seed_steps: ??? bin_size: ??? -# compile +# speedups compile: False -cudagraphs: False diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index a4da1db..8e6b435 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,15 +1,11 @@ import torch import torch.nn.functional as F -import functools from common import math from common.scale import RunningScale from common.world_model import WorldModel -from tensordict.nn import CudaGraphModule from tensordict import TensorDict -CG_WARMUP = 1000 - class TDMPC2(torch.nn.Module): """ @@ -21,11 +17,8 @@ class TDMPC2(torch.nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg - self.device = torch.device('cuda:0') - self.model = WorldModel(cfg).to(self.device) - capturable = True self.optim = torch.optim.Adam([ {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, @@ -33,8 +26,8 @@ class TDMPC2(torch.nn.Module): {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] } - ], lr=self.cfg.lr, capturable=capturable) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable) + ], lr=self.cfg.lr, capturable=True) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True) self.model.eval() self.scale = RunningScale(cfg) self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces @@ -42,43 +35,16 @@ class TDMPC2(torch.nn.Module): [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) - if cfg.compile: - mode = None if cfg.cudagraphs else "reduce-overhead" print('compiling - update') - self._update = torch.compile(self._update, mode=mode) - - if cfg.cudagraphs: - print('cudagraphs - update') - self._update = CudaGraphModule(self._update, warmup=CG_WARMUP) + self._update = torch.compile(self._update, mode="reduce-overhead") @property def plan(self): _plan_val = getattr(self, "_plan_val", None) if _plan_val is not None: return _plan_val - if self.cfg.cudagraphs: - print('cudagraphs - plan') - self._plan_dict = { - (True, True): functools.partial(self._plan, t0=True, eval_mode=True), - (False, True): functools.partial(self._plan, t0=False, eval_mode=True), - (True, False): functools.partial(self._plan, t0=True, eval_mode=False), - (False, False): functools.partial(self._plan, t0=False, eval_mode=False), - } - if self.cfg.compile: - print('compiling - plan') - mode = None - self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()} - - self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()} - def plan(obs, t0=False, eval_mode=False, task=None): - if task is not None: - kwargs = {"task": task} - else: - kwargs = {} - torch.compiler.cudagraph_mark_step_begin() - return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs) - elif self.cfg.compile: + if self.cfg.compile: plan = torch.compile(self._plan, mode="reduce-overhead") else: plan = self._plan @@ -247,7 +213,6 @@ class TDMPC2(torch.nn.Module): pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() - # For some reason, cudagraph prefers to see the zero grad after step self.pi_optim.zero_grad(set_to_none=True) return pi_loss.detach(), pi_grad_norm @@ -269,23 +234,6 @@ class TDMPC2(torch.nn.Module): discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) - def update(self, buffer): - """ - Main update function. Corresponds to one iteration of model learning. - - Args: - buffer (common.buffer.Buffer): Replay buffer. - - Returns: - dict: Dictionary of training statistics. - """ - obs, action, reward, task = buffer.sample() - kwargs = {} - if task is not None: - kwargs["task"] = task - torch.compiler.cudagraph_mark_step_begin() - return self._update(obs, action, reward, **kwargs) - def _update(self, obs, action, reward, task=None): # Compute targets with torch.no_grad(): @@ -314,7 +262,7 @@ class TDMPC2(torch.nn.Module): reward_loss, value_loss = 0, 0 for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))): reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t - for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): + for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t consistency_loss = consistency_loss / self.cfg.horizon @@ -350,3 +298,20 @@ class TDMPC2(torch.nn.Module): "pi_grad_norm": pi_grad_norm, "pi_scale": self.scale.value, }).detach().mean() + + def update(self, buffer): + """ + Main update function. Corresponds to one iteration of model learning. + + Args: + buffer (common.buffer.Buffer): Replay buffer. + + Returns: + dict: Dictionary of training statistics. + """ + obs, action, reward, task = buffer.sample() + kwargs = {} + if task is not None: + kwargs["task"] = task + torch.compiler.cudagraph_mark_step_begin() + return self._update(obs, action, reward, **kwargs) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 103d129..3a47542 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -68,11 +68,11 @@ class OnlineTrainer(Trainer): train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: # Evaluate agent periodically - if self._step > 0 and self._step % self.cfg.eval_freq == 0: + if self._step % self.cfg.eval_freq == 0: eval_next = True # Reset environment - if done or (self._step == self.cfg.seed_steps + 1): + if done: if eval_next: eval_metrics = self.eval() eval_metrics.update(self.common_metrics()) From fad0d1be0361d15997c9cf8594f1f47dad226ca4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 26 Oct 2024 00:32:16 +0100 Subject: [PATCH 05/12] Use torch.compiler.cudagraph_mark_step_begin() in eval --- tdmpc2/trainer/online_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 3a47542..0d2f062 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -31,6 +31,7 @@ class OnlineTrainer(Trainer): if self.cfg.save_video: self.logger.video.init(self.env, enabled=(i==0)) while not done: + torch.compiler.cudagraph_mark_step_begin() action = self.agent.act(obs, t0=t==0, eval_mode=True) obs, reward, done, info = self.env.step(action) ep_reward += reward From 3b5f67592ccd8d4c511bea490e97b33cd67a7a7b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 26 Oct 2024 00:33:27 +0100 Subject: [PATCH 06/12] Update offline_trainer.py --- tdmpc2/trainer/offline_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index 1bace8e..89f1c20 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -27,6 +27,7 @@ class OfflineTrainer(Trainer): for _ in range(self.cfg.eval_episodes): obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0 while not done: + torch.compiler.cudagraph_mark_step_begin() action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx) obs, reward, done, info = self.env.step(action) ep_reward += reward From c1dd0c0338dbb2acc95ee506cf1907087311ca44 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 27 Oct 2024 14:24:19 -0700 Subject: [PATCH 07/12] minor QoL improvements in offline pipeline --- datasets/download_mt30.sh | 1 + datasets/download_mt80.sh | 1 + docker/environment.yaml | 2 +- tdmpc2/common/parser.py | 2 +- tdmpc2/trainer/offline_trainer.py | 10 +++++----- 5 files changed, 9 insertions(+), 7 deletions(-) create mode 100644 datasets/download_mt30.sh create mode 100644 datasets/download_mt80.sh diff --git a/datasets/download_mt30.sh b/datasets/download_mt30.sh new file mode 100644 index 0000000..2073bcb --- /dev/null +++ b/datasets/download_mt30.sh @@ -0,0 +1 @@ +for i in {0..3}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt30/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done \ No newline at end of file diff --git a/datasets/download_mt80.sh b/datasets/download_mt80.sh new file mode 100644 index 0000000..01a7c46 --- /dev/null +++ b/datasets/download_mt80.sh @@ -0,0 +1 @@ +for i in {0..19}; do wget https://huggingface.co/datasets/nicklashansen/tdmpc2/resolve/main/mt80/chunk_${i}.pt?download=true && mv chunk_${i}.pt'?download=true' chunk_${i}.pt; done \ No newline at end of file diff --git a/docker/environment.yaml b/docker/environment.yaml index 857c81a..9425459 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -1,4 +1,4 @@ -name: graph +name: tdmpc2 channels: - pytorch-nightly - nvidia diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index ddce2b4..378ba4a 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -53,7 +53,7 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: if cfg.multitask: cfg.task_title = cfg.task.upper() # Account for slight inconsistency in task_dim for the mt30 experiments - cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.model_size in {1, 317} else 64 + cfg.task_dim = 96 if cfg.task == 'mt80' or cfg.get('model_size', 5) in {1, 317} else 64 else: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index 1bace8e..a4289d9 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -44,13 +44,12 @@ class OfflineTrainer(Trainer): 'Offline training only supports multitask training with mt30 or mt80 task sets.' # Load data - assert self.cfg.task in self.cfg.data_dir, \ - f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \ - f'please double-check your config.' fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fps = sorted(glob(str(fp))) assert len(fps) > 0, f'No data found at {fp}' print(f'Found {len(fps)} files in {fp}') + assert len(fps) == (20 if self.cfg.task == 'mt80' else 4), \ + f'Expected 20 files for mt80 task set, 4 files for mt30 task set, found {len(fps)} files.' # Create buffer for sampling _cfg = deepcopy(self.cfg) @@ -65,8 +64,9 @@ class OfflineTrainer(Trainer): f'please double-check your config.' for i in range(len(td)): self.buffer.add(td[i]) - assert self.buffer.num_eps == self.buffer.capacity, \ - f'Buffer has {self.buffer.num_eps} episodes, expected {self.buffer.capacity} episodes.' + expected_episodes = _cfg.buffer_size // _cfg.episode_length + assert self.buffer.num_eps == expected_episodes, \ + f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.' print(f'Training agent for {self.cfg.steps} iterations...') metrics = {} From b7725e74a5320b3046bf604857d983737ed49393 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Thu, 31 Oct 2024 14:52:59 -0700 Subject: [PATCH 08/12] move cfg conversion to parser.py --- tdmpc2/common/parser.py | 21 ++++++++++++++++++++- tdmpc2/train.py | 36 ------------------------------------ 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index 378ba4a..a8d9f25 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -1,5 +1,7 @@ +import dataclasses import re from pathlib import Path +from typing import Any import hydra from omegaconf import OmegaConf @@ -7,6 +9,23 @@ from omegaconf import OmegaConf from common import MODEL_SIZE, TASK_SET +def cfg_to_dataclass(cfg, frozen=False): + """ + Converts an OmegaConf config to a dataclass object. + This prevents graph breaks when used with torch.compile. + """ + cfg_dict = OmegaConf.to_container(cfg) + fields = [] + for key, value in cfg_dict.items(): + fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) + dataclass_name = "Config" + dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) + def get(self, val, default=None): + return getattr(self, val, default) + dataclass.get = get + return dataclass() + + def parse_cfg(cfg: OmegaConf) -> OmegaConf: """ Parses a Hydra config. Mostly for convenience. @@ -58,4 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) - return cfg + return cfg_to_dataclass(cfg) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 48206ec..3dc37a6 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -18,42 +18,9 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger -import dataclasses -from typing import Any -from omegaconf import OmegaConf torch.backends.cudnn.benchmark = True - torch.set_float32_matmul_precision('high') -def cfg_to_dataclass(cfg, frozen=False): - # Converts an OmegaConf config to a dataclass, which will not cause graph breaks - cfg_dict = OmegaConf.to_container(cfg) - fields = [] - for key, value in cfg_dict.items(): - fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) - - # Create the dataclass - dataclass_name = "Config" - dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) - def get(self, val, default=None): - return getattr(self, val, default) - dataclass.get = get - return dataclass() - -def cfg_to_dataclass(cfg, frozen=False): - # Converts an OmegaConf config to a dataclass, which will not cause graph breaks - cfg_dict = OmegaConf.to_container(cfg) - fields = [] - for key, value in cfg_dict.items(): - fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) - - # Create the dataclass - dataclass_name = "Config" - dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) - def get(self, val, default=None): - return getattr(self, val, default) - dataclass.get = get - return dataclass() @hydra.main(config_name='config', config_path='.') def train(cfg: dict): @@ -82,9 +49,6 @@ def train(cfg: dict): print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer - - cfg = cfg_to_dataclass(cfg) - trainer = trainer_cls( cfg=cfg, env=make_env(cfg), From 1a7720764616e4e55621a970e87e86fee7ef60ed Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 4 Nov 2024 15:15:40 -0800 Subject: [PATCH 09/12] support newest version of myosuite --- docker/environment.yaml | 3 +-- tdmpc2/envs/myosuite.py | 9 ++++++--- tdmpc2/train.py | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/docker/environment.yaml b/docker/environment.yaml index 9425459..9da54e0 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -39,7 +39,7 @@ dependencies: - gpustat==1.1.1 #################### # Gym: - # (unmaintained but required for maniskill2/meta-world/myosuite) + # (unmaintained but required for maniskill2/meta-world) - gym==0.21.0 #################### # ManiSkill2: @@ -51,6 +51,5 @@ dependencies: # - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb #################### # MyoSuite: - # (requires gym==0.13 which conflicts with meta-world / mani-skill2) # - myosuite #################### diff --git a/tdmpc2/envs/myosuite.py b/tdmpc2/envs/myosuite.py index fa6876e..d15f11f 100644 --- a/tdmpc2/envs/myosuite.py +++ b/tdmpc2/envs/myosuite.py @@ -24,9 +24,11 @@ class MyoSuiteWrapper(gym.Wrapper): self.cfg = cfg self.camera_id = 'hand_side_inter' + def reset(self): + return self.env.reset()[0] + def step(self, action): - obs, reward, _, info = self.env.step(action.copy()) - obs = obs.astype(np.float32) + obs, reward, _, _, info = self.env.step(action.copy()) info['success'] = info['solved'] return obs, reward, False, info @@ -48,7 +50,8 @@ def make_env(cfg): raise ValueError('Unknown task:', cfg.task) assert cfg.obs == 'state', 'This task only supports state observations.' import myosuite - env = gym.make(MYOSUITE_TASKS[cfg.task]) + from myosuite.utils import gym as gym_utils + env = gym_utils.make(MYOSUITE_TASKS[cfg.task]) env = MyoSuiteWrapper(env, cfg) env = TimeLimit(env, max_episode_steps=100) env.max_episode_steps = env._max_episode_steps diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 3dc37a6..1846145 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -18,6 +18,7 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger + torch.backends.cudnn.benchmark = True torch.set_float32_matmul_precision('high') From c694d286f04ca4133cc160227a5575913a9177d1 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 10 Nov 2024 12:25:43 -0800 Subject: [PATCH 10/12] add assertion for compile=true compatibility --- docker/environment.yaml | 9 +++++---- tdmpc2/common/parser.py | 5 +++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/docker/environment.yaml b/docker/environment.yaml index 9da54e0..87cab71 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -14,6 +14,7 @@ dependencies: - torchvision - pip: - absl-py==2.1.0 + - "cython<3" - dm-control==1.0.8 - glfw==2.7.0 - ffmpeg==1.4 @@ -23,6 +24,8 @@ dependencies: - hydra-core==1.3.2 - hydra-submitit-launcher==1.2.0 - submitit==1.5.1 + - setuptools==65.5.0 + - patchelf==0.17.2.1 - omegaconf==2.3.0 - moviepy==1.0.3 - mujoco==2.3.1 @@ -34,13 +37,11 @@ dependencies: - tqdm==4.66.4 - pandas==2.0.3 - wandb==0.17.4 - - matplotlib==3.7.5 - - seaborn==0.13.2 - - gpustat==1.1.1 + - wheel==0.38.0 #################### # Gym: # (unmaintained but required for maniskill2/meta-world) - - gym==0.21.0 + # - gym==0.21.0 #################### # ManiSkill2: # (requires gym==0.21.0 which occasionally breaks) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index a8d9f25..e162eac 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -77,4 +77,9 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) + # Check torch.compile compatibility + if cfg.get('compile', False): + assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' + assert not cfg.multitask, 'torch.compile does not support multitask training at the moment.' + return cfg_to_dataclass(cfg) From fb07cdac3f0934a6c3977b79e25f0db312f7b695 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 10 Nov 2024 12:27:24 -0800 Subject: [PATCH 11/12] update compile print --- tdmpc2/tdmpc2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 8e6b435..e4d8ec2 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module): ) if self.cfg.multitask else self._get_discount(cfg.episode_length) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: - print('compiling - update') + print('Compiling update function with torch.compile...') self._update = torch.compile(self._update, mode="reduce-overhead") @property From 0c3fcc46190e3f5dd9a5be805989418b0d280483 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 10 Nov 2024 12:32:57 -0800 Subject: [PATCH 12/12] update readme --- README.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0e74c2f..441ddd6 100755 --- a/README.md +++ b/README.md @@ -12,11 +12,9 @@ Official implementation of ---- -**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.** +**Announcement: training just got ~4.5x faster!** -Expect **3-8x** faster wall-time (depending on hardware and task) compared to `main` branch. A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository. - -Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to this branch! +Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility! ----