fix eval index + clean up

This commit is contained in:
Nicklas Hansen
2024-10-21 14:49:21 -07:00
parent 970792e2b6
commit 836547d76f
6 changed files with 39 additions and 68 deletions

View File

@@ -12,6 +12,15 @@ Official implementation of
----
**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.**
Expect **3-8x** faster wall-time (depending on hardware and task) compared to `main` branch. A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository.
Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to this branch!
----
## Overview
TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*).

View File

@@ -2,9 +2,11 @@ import dataclasses
import os
import datetime
import re
import numpy as np
import pandas as pd
from termcolor import colored
from common import TASK_SET
@@ -115,7 +117,7 @@ class Logger:
print_run(cfg)
self.project = cfg.get("wandb_project", "none")
self.entity = cfg.get("wandb_entity", "none")
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
if not cfg.enable_wandb or self.project == "none" or self.entity == "none":
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
cfg.save_agent = False
cfg.save_video = False

View File

@@ -9,12 +9,10 @@ def soft_ce(pred, target, cfg):
return -(target * pred).sum(-1, keepdim=True)
def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1)
def _gaussian_residual(eps, log_std):
return -0.5 * eps.pow(2) - log_std
@@ -32,7 +30,6 @@ def gaussian_logprob(eps, log_std, size=None):
return _gaussian_logprob(residual) * size
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
@@ -45,7 +42,6 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi
def symlog(x):
"""
Symmetric logarithmic function.
@@ -54,7 +50,6 @@ def symlog(x):
return torch.sign(x) * torch.log(1 + torch.abs(x))
def symexp(x):
"""
Symmetric exponential function.
@@ -90,6 +85,7 @@ def two_hot_inv(x, cfg):
x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
return symexp(x)
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
logits = p.log()
# Generate Gumbel noise

View File

@@ -65,7 +65,7 @@ simnorm_dim: 8
wandb_project: ???
wandb_entity: ???
wandb_silent: false
disable_wandb: true
enable_wandb: true
save_csv: true
# misc
@@ -87,6 +87,5 @@ episode_lengths: ???
seed_steps: ???
bin_size: ???
# compile
# speedups
compile: False
cudagraphs: False

View File

@@ -1,15 +1,11 @@
import torch
import torch.nn.functional as F
import functools
from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
from tensordict.nn import CudaGraphModule
from tensordict import TensorDict
CG_WARMUP = 1000
class TDMPC2(torch.nn.Module):
"""
@@ -21,11 +17,8 @@ class TDMPC2(torch.nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.device = torch.device('cuda:0')
self.model = WorldModel(cfg).to(self.device)
capturable = True
self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
@@ -33,8 +26,8 @@ class TDMPC2(torch.nn.Module):
{'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
}
], lr=self.cfg.lr, capturable=capturable)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable)
], lr=self.cfg.lr, capturable=True)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True)
self.model.eval()
self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
@@ -42,43 +35,16 @@ class TDMPC2(torch.nn.Module):
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
mode = None if cfg.cudagraphs else "reduce-overhead"
print('compiling - update')
self._update = torch.compile(self._update, mode=mode)
if cfg.cudagraphs:
print('cudagraphs - update')
self._update = CudaGraphModule(self._update, warmup=CG_WARMUP)
self._update = torch.compile(self._update, mode="reduce-overhead")
@property
def plan(self):
_plan_val = getattr(self, "_plan_val", None)
if _plan_val is not None:
return _plan_val
if self.cfg.cudagraphs:
print('cudagraphs - plan')
self._plan_dict = {
(True, True): functools.partial(self._plan, t0=True, eval_mode=True),
(False, True): functools.partial(self._plan, t0=False, eval_mode=True),
(True, False): functools.partial(self._plan, t0=True, eval_mode=False),
(False, False): functools.partial(self._plan, t0=False, eval_mode=False),
}
if self.cfg.compile:
print('compiling - plan')
mode = None
self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()}
self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()}
def plan(obs, t0=False, eval_mode=False, task=None):
if task is not None:
kwargs = {"task": task}
else:
kwargs = {}
torch.compiler.cudagraph_mark_step_begin()
return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs)
elif self.cfg.compile:
if self.cfg.compile:
plan = torch.compile(self._plan, mode="reduce-overhead")
else:
plan = self._plan
@@ -247,7 +213,6 @@ class TDMPC2(torch.nn.Module):
pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step()
# For some reason, cudagraph prefers to see the zero grad after step
self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.detach(), pi_grad_norm
@@ -269,23 +234,6 @@ class TDMPC2(torch.nn.Module):
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs)
def _update(self, obs, action, reward, task=None):
# Compute targets
with torch.no_grad():
@@ -314,7 +262,7 @@ class TDMPC2(torch.nn.Module):
reward_loss, value_loss = 0, 0
for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
consistency_loss = consistency_loss / self.cfg.horizon
@@ -350,3 +298,20 @@ class TDMPC2(torch.nn.Module):
"pi_grad_norm": pi_grad_norm,
"pi_scale": self.scale.value,
}).detach().mean()
def update(self, buffer):
"""
Main update function. Corresponds to one iteration of model learning.
Args:
buffer (common.buffer.Buffer): Replay buffer.
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
kwargs = {}
if task is not None:
kwargs["task"] = task
torch.compiler.cudagraph_mark_step_begin()
return self._update(obs, action, reward, **kwargs)

View File

@@ -68,11 +68,11 @@ class OnlineTrainer(Trainer):
train_metrics, done, eval_next = {}, True, False
while self._step <= self.cfg.steps:
# Evaluate agent periodically
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
if self._step % self.cfg.eval_freq == 0:
eval_next = True
# Reset environment
if done or (self._step == self.cfg.seed_steps + 1):
if done:
if eval_next:
eval_metrics = self.eval()
eval_metrics.update(self.common_metrics())