diff --git a/README.md b/README.md index 687b73d..0e74c2f 100755 --- a/README.md +++ b/README.md @@ -12,6 +12,15 @@ Official implementation of ---- +**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.** + +Expect **3-8x** faster wall-clock time (depending on hardware and task) compared to the `main` branch. A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository. + +Thank you to [Vincent Moens](https://github.com/vmoens), who has been a key contributor to this branch! + +---- + + ## Overview TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*). 
diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index bdbb998..8ea2c2e 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -2,9 +2,11 @@ import dataclasses import os import datetime import re + import numpy as np import pandas as pd from termcolor import colored + from common import TASK_SET @@ -115,7 +117,7 @@ class Logger: print_run(cfg) self.project = cfg.get("wandb_project", "none") self.entity = cfg.get("wandb_entity", "none") - if cfg.disable_wandb or self.project == "none" or self.entity == "none": + if not cfg.enable_wandb or self.project == "none" or self.entity == "none": print(colored("Wandb disabled.", "blue", attrs=["bold"])) cfg.save_agent = False cfg.save_video = False diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 91fcdce..5ac92ad 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -9,12 +9,10 @@ def soft_ce(pred, target, cfg): return -(target * pred).sum(-1, keepdim=True) - def log_std(x, low, dif): return low + 0.5 * dif * (torch.tanh(x) + 1) - def _gaussian_residual(eps, log_std): return -0.5 * eps.pow(2) - log_std @@ -32,7 +30,6 @@ def gaussian_logprob(eps, log_std, size=None): return _gaussian_logprob(residual) * size - def _squash(pi): return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) @@ -45,7 +42,6 @@ def squash(mu, pi, log_pi): return mu, pi, log_pi - def symlog(x): """ Symmetric logarithmic function. @@ -54,7 +50,6 @@ def symlog(x): return torch.sign(x) * torch.log(1 + torch.abs(x)) - def symexp(x): """ Symmetric exponential function. @@ -90,6 +85,7 @@ def two_hot_inv(x, cfg): x = torch.sum(x * dreg_bins, dim=-1, keepdim=True) return symexp(x) + def gumbel_softmax_sample(p, temperature=1.0, dim=0): logits = p.log() # Generate Gumbel noise diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 441e421..597c829 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -65,7 +65,7 @@ simnorm_dim: 8 wandb_project: ??? wandb_entity: ??? 
wandb_silent: false -disable_wandb: true +enable_wandb: true save_csv: true # misc @@ -87,6 +87,5 @@ episode_lengths: ??? seed_steps: ??? bin_size: ??? -# compile +# speedups compile: False -cudagraphs: False diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index a4da1db..8e6b435 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -1,15 +1,11 @@ import torch import torch.nn.functional as F -import functools from common import math from common.scale import RunningScale from common.world_model import WorldModel -from tensordict.nn import CudaGraphModule from tensordict import TensorDict -CG_WARMUP = 1000 - class TDMPC2(torch.nn.Module): """ @@ -21,11 +17,8 @@ class TDMPC2(torch.nn.Module): def __init__(self, cfg): super().__init__() self.cfg = cfg - self.device = torch.device('cuda:0') - self.model = WorldModel(cfg).to(self.device) - capturable = True self.optim = torch.optim.Adam([ {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, @@ -33,8 +26,8 @@ class TDMPC2(torch.nn.Module): {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] } - ], lr=self.cfg.lr, capturable=capturable) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable) + ], lr=self.cfg.lr, capturable=True) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True) self.model.eval() self.scale = RunningScale(cfg) self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces @@ -42,43 +35,16 @@ class TDMPC2(torch.nn.Module): [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) - if cfg.compile: - mode = None if cfg.cudagraphs else 
"reduce-overhead" print('compiling - update') - self._update = torch.compile(self._update, mode=mode) - - if cfg.cudagraphs: - print('cudagraphs - update') - self._update = CudaGraphModule(self._update, warmup=CG_WARMUP) + self._update = torch.compile(self._update, mode="reduce-overhead") @property def plan(self): _plan_val = getattr(self, "_plan_val", None) if _plan_val is not None: return _plan_val - if self.cfg.cudagraphs: - print('cudagraphs - plan') - self._plan_dict = { - (True, True): functools.partial(self._plan, t0=True, eval_mode=True), - (False, True): functools.partial(self._plan, t0=False, eval_mode=True), - (True, False): functools.partial(self._plan, t0=True, eval_mode=False), - (False, False): functools.partial(self._plan, t0=False, eval_mode=False), - } - if self.cfg.compile: - print('compiling - plan') - mode = None - self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()} - - self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()} - def plan(obs, t0=False, eval_mode=False, task=None): - if task is not None: - kwargs = {"task": task} - else: - kwargs = {} - torch.compiler.cudagraph_mark_step_begin() - return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs) - elif self.cfg.compile: + if self.cfg.compile: plan = torch.compile(self._plan, mode="reduce-overhead") else: plan = self._plan @@ -247,7 +213,6 @@ class TDMPC2(torch.nn.Module): pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() - # For some reason, cudagraph prefers to see the zero grad after step self.pi_optim.zero_grad(set_to_none=True) return pi_loss.detach(), pi_grad_norm @@ -269,23 +234,6 @@ class TDMPC2(torch.nn.Module): discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) - def update(self, 
buffer): - """ - Main update function. Corresponds to one iteration of model learning. - - Args: - buffer (common.buffer.Buffer): Replay buffer. - - Returns: - dict: Dictionary of training statistics. - """ - obs, action, reward, task = buffer.sample() - kwargs = {} - if task is not None: - kwargs["task"] = task - torch.compiler.cudagraph_mark_step_begin() - return self._update(obs, action, reward, **kwargs) - def _update(self, obs, action, reward, task=None): # Compute targets with torch.no_grad(): @@ -314,7 +262,7 @@ class TDMPC2(torch.nn.Module): reward_loss, value_loss = 0, 0 for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))): reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t - for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): + for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t consistency_loss = consistency_loss / self.cfg.horizon @@ -350,3 +298,20 @@ class TDMPC2(torch.nn.Module): "pi_grad_norm": pi_grad_norm, "pi_scale": self.scale.value, }).detach().mean() + + def update(self, buffer): + """ + Main update function. Corresponds to one iteration of model learning. + + Args: + buffer (common.buffer.Buffer): Replay buffer. + + Returns: + dict: Dictionary of training statistics. 
+ """ + obs, action, reward, task = buffer.sample() + kwargs = {} + if task is not None: + kwargs["task"] = task + torch.compiler.cudagraph_mark_step_begin() + return self._update(obs, action, reward, **kwargs) diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 103d129..3a47542 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -68,11 +68,11 @@ class OnlineTrainer(Trainer): train_metrics, done, eval_next = {}, True, False while self._step <= self.cfg.steps: # Evaluate agent periodically - if self._step > 0 and self._step % self.cfg.eval_freq == 0: + if self._step % self.cfg.eval_freq == 0: eval_next = True # Reset environment - if done or (self._step == self.cfg.seed_steps + 1): + if done: if eval_next: eval_metrics = self.eval() eval_metrics.update(self.common_metrics())