fix eval index + clean up

This commit is contained in:
Nicklas Hansen
2024-10-21 14:49:21 -07:00
parent 970792e2b6
commit 836547d76f
6 changed files with 39 additions and 68 deletions

View File

@@ -12,6 +12,15 @@ Official implementation of
---- ----
**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.**
Expect **3-8x** faster wall-time (depending on hardware and task) compared to `main` branch. A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository.
Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to this branch!
----
## Overview ## Overview
TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*). TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*).

View File

@@ -2,9 +2,11 @@ import dataclasses
import os import os
import datetime import datetime
import re import re
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from termcolor import colored from termcolor import colored
from common import TASK_SET from common import TASK_SET
@@ -115,7 +117,7 @@ class Logger:
print_run(cfg) print_run(cfg)
self.project = cfg.get("wandb_project", "none") self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none") self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none": if not cfg.enable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"])) print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False cfg.save_agent = False
cfg.save_video = False cfg.save_video = False

View File

@@ -9,12 +9,10 @@ def soft_ce(pred, target, cfg):
return -(target * pred).sum(-1, keepdim=True) return -(target * pred).sum(-1, keepdim=True)
def log_std(x, low, dif): def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1) return low + 0.5 * dif * (torch.tanh(x) + 1)
def _gaussian_residual(eps, log_std): def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std return -0.5 * eps.pow(2) - log_std
@@ -32,7 +30,6 @@ def gaussian_logprob(eps, log_std, size=None):
return _gaussian_logprob(residual) * size return _gaussian_logprob(residual) * size
def _squash(pi): def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
@@ -45,7 +42,6 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi return mu, pi, log_pi
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.
@@ -54,7 +50,6 @@ def symlog(x):
return torch.sign(x) * torch.log(1 + torch.abs(x)) return torch.sign(x) * torch.log(1 + torch.abs(x))
def symexp(x): def symexp(x):
""" """
Symmetric exponential function. Symmetric exponential function.
@@ -90,6 +85,7 @@ def two_hot_inv(x, cfg):
x = torch.sum(x * dreg_bins, dim=-1, keepdim=True) x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
return symexp(x) return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0): def gumbel_softmax_sample(p, temperature=1.0, dim=0):
logits = p.log() logits = p.log()
# Generate Gumbel noise # Generate Gumbel noise

View File

@@ -65,7 +65,7 @@ simnorm_dim: 8
wandb_project: ??? wandb_project: ???
wandb_entity: ??? wandb_entity: ???
wandb_silent: false wandb_silent: false
disable_wandb: true enable_wandb: true
save_csv: true save_csv: true
# misc # misc
@@ -87,6 +87,5 @@ episode_lengths: ???
seed_steps: ??? seed_steps: ???
bin_size: ??? bin_size: ???
# compile # speedups
compile: False compile: False
cudagraphs: False

View File

@@ -1,15 +1,11 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import functools
from common import math from common import math
from common.scale import RunningScale from common.scale import RunningScale
from common.world_model import WorldModel from common.world_model import WorldModel
from tensordict.nn import CudaGraphModule
from tensordict import TensorDict from tensordict import TensorDict
CG_WARMUP = 1000
class TDMPC2(torch.nn.Module): class TDMPC2(torch.nn.Module):
""" """
@@ -21,11 +17,8 @@ class TDMPC2(torch.nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda:0') self.device = torch.device('cuda:0')
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg).to(self.device)
capturable = True
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
@@ -33,8 +26,8 @@ class TDMPC2(torch.nn.Module):
{'params': self.model._Qs.parameters()}, {'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else [] {'params': self.model._task_emb.parameters() if self.cfg.multitask else []
} }
], lr=self.cfg.lr, capturable=capturable) ], lr=self.cfg.lr, capturable=True)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable) self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True)
self.model.eval() self.model.eval()
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
@@ -42,43 +35,16 @@ class TDMPC2(torch.nn.Module):
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
mode = None if cfg.cudagraphs else "reduce-overhead"
print('compiling - update') print('compiling - update')
self._update = torch.compile(self._update, mode=mode) self._update = torch.compile(self._update, mode="reduce-overhead")
if cfg.cudagraphs:
print('cudagraphs - update')
self._update = CudaGraphModule(self._update, warmup=CG_WARMUP)
@property @property
def plan(self): def plan(self):
_plan_val = getattr(self, "_plan_val", None) _plan_val = getattr(self, "_plan_val", None)
if _plan_val is not None: if _plan_val is not None:
return _plan_val return _plan_val
if self.cfg.cudagraphs:
print('cudagraphs - plan')
self._plan_dict = {
(True, True): functools.partial(self._plan, t0=True, eval_mode=True),
(False, True): functools.partial(self._plan, t0=False, eval_mode=True),
(True, False): functools.partial(self._plan, t0=True, eval_mode=False),
(False, False): functools.partial(self._plan, t0=False, eval_mode=False),
}
if self.cfg.compile: if self.cfg.compile:
print('compiling - plan')
mode = None
self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()}
self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()}
def plan(obs, t0=False, eval_mode=False, task=None):
if task is not None:
kwargs = {"task": task}
else:
kwargs = {}
torch.compiler.cudagraph_mark_step_begin()
return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs)
elif self.cfg.compile:
plan = torch.compile(self._plan, mode="reduce-overhead") plan = torch.compile(self._plan, mode="reduce-overhead")
else: else:
plan = self._plan plan = self._plan
@@ -247,7 +213,6 @@ class TDMPC2(torch.nn.Module):
pi_loss.backward() pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
# For some reason, cudagraph prefers to see the zero grad after step
self.pi_optim.zero_grad(set_to_none=True) self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.detach(), pi_grad_norm return pi_loss.detach(), pi_grad_norm
@@ -269,23 +234,6 @@ class TDMPC2(torch.nn.Module):
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs)
def _update(self, obs, action, reward, task=None): def _update(self, obs, action, reward, task=None):
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
@@ -314,7 +262,7 @@ class TDMPC2(torch.nn.Module):
reward_loss, value_loss = 0, 0 reward_loss, value_loss = 0, 0
for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))): for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)): for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
consistency_loss = consistency_loss / self.cfg.horizon consistency_loss = consistency_loss / self.cfg.horizon
@@ -350,3 +298,20 @@ class TDMPC2(torch.nn.Module):
"pi_grad_norm": pi_grad_norm, "pi_grad_norm": pi_grad_norm,
"pi_scale": self.scale.value, "pi_scale": self.scale.value,
}).detach().mean() }).detach().mean()
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs)

View File

@@ -68,11 +68,11 @@ class OnlineTrainer(Trainer):
train_metrics, done, eval_next = {}, True, False train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps: while self._step <= self.cfg.steps:
# Evaluate agent periodically # Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0: if self._step % self.cfg.eval_freq == 0:
eval_next = True eval_next = True
# Reset environment # Reset environment
if done or (self._step == self.cfg.seed_steps + 1): if done:
if eval_next: if eval_next:
eval_metrics = self.eval() eval_metrics = self.eval()
eval_metrics.update(self.common_metrics()) eval_metrics.update(self.common_metrics())