fix eval index + clean up
This commit is contained in:
@@ -12,6 +12,15 @@ Official implementation of
|
||||
|
||||
----
|
||||
|
||||
**Note: the `speedups` branch is experimental and may contain bugs. Please use the `main` branch for the latest stable release.**
|
||||
|
||||
Expect **3-8x** faster wall-time (depending on hardware and task) compared to the `main` branch. Most of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. We are not responsible for any issues that may arise from using our repository.
|
||||
|
||||
Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to this branch!
|
||||
|
||||
----
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
TD-MPC**2** is a scalable, robust model-based reinforcement learning algorithm. It compares favorably to existing model-free and model-based methods across **104** continuous control tasks spanning multiple domains, with a *single* set of hyperparameters (*right*). We further demonstrate the scalability of TD-MPC**2** by training a single 317M parameter agent to perform **80** tasks across multiple domains, embodiments, and action spaces (*left*).
|
||||
|
||||
@@ -2,9 +2,11 @@ import dataclasses
|
||||
import os
|
||||
import datetime
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from termcolor import colored
|
||||
|
||||
from common import TASK_SET
|
||||
|
||||
|
||||
@@ -115,7 +117,7 @@ class Logger:
|
||||
print_run(cfg)
|
||||
self.project = cfg.get("wandb_project", "none")
|
||||
self.entity = cfg.get("wandb_entity", "none")
|
||||
if cfg.disable_wandb or self.project == "none" or self.entity == "none":
|
||||
if not cfg.enable_wandb or self.project == "none" or self.entity == "none":
|
||||
print(colored("Wandb disabled.", "blue", attrs=["bold"]))
|
||||
cfg.save_agent = False
|
||||
cfg.save_video = False
|
||||
|
||||
@@ -9,12 +9,10 @@ def soft_ce(pred, target, cfg):
|
||||
return -(target * pred).sum(-1, keepdim=True)
|
||||
|
||||
|
||||
|
||||
def log_std(x, low, dif):
    """
    Squash `x` with tanh and rescale the result into [low, low + dif].

    Args:
        x: Input tensor passed to `torch.tanh`.
        low: Lower end of the output interval.
        dif: Width of the output interval.

    Returns:
        `low + 0.5 * dif * (tanh(x) + 1)`, computed elementwise.
    """
    unit = torch.tanh(x) + 1  # tanh(x) + 1 lies in [0, 2]
    half_width = 0.5 * dif
    return low + half_width * unit
|
||||
|
||||
|
||||
|
||||
def _gaussian_residual(eps, log_std):
    """
    Compute `-0.5 * eps^2 - log_std` elementwise.

    Args:
        eps: Tensor that is squared via `.pow(2)`.
        log_std: Value subtracted from the scaled square.

    Returns:
        The elementwise residual `-0.5 * eps.pow(2) - log_std`.
    """
    squared = eps.pow(2)
    return -0.5 * squared - log_std
|
||||
|
||||
@@ -32,7 +30,6 @@ def gaussian_logprob(eps, log_std, size=None):
|
||||
return _gaussian_logprob(residual) * size
|
||||
|
||||
|
||||
|
||||
def _squash(pi):
    """
    Compute `log(relu(1 - pi^2) + 1e-6)` elementwise.

    Args:
        pi: Tensor that is squared via `.pow(2)`.

    Returns:
        Tensor of logarithms; always finite because the log argument is >= 1e-6.
    """
    # relu clamps negative values of 1 - pi^2 to zero; the 1e-6 offset keeps
    # the argument of log strictly positive.
    clamped = F.relu(1 - pi.pow(2))
    return torch.log(clamped + 1e-6)
|
||||
|
||||
@@ -45,7 +42,6 @@ def squash(mu, pi, log_pi):
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
|
||||
def symlog(x):
|
||||
"""
|
||||
Symmetric logarithmic function.
|
||||
@@ -54,7 +50,6 @@ def symlog(x):
|
||||
return torch.sign(x) * torch.log(1 + torch.abs(x))
|
||||
|
||||
|
||||
|
||||
def symexp(x):
|
||||
"""
|
||||
Symmetric exponential function.
|
||||
@@ -90,6 +85,7 @@ def two_hot_inv(x, cfg):
|
||||
x = torch.sum(x * dreg_bins, dim=-1, keepdim=True)
|
||||
return symexp(x)
|
||||
|
||||
|
||||
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
||||
logits = p.log()
|
||||
# Generate Gumbel noise
|
||||
|
||||
@@ -65,7 +65,7 @@ simnorm_dim: 8
|
||||
wandb_project: ???
|
||||
wandb_entity: ???
|
||||
wandb_silent: false
|
||||
disable_wandb: true
|
||||
enable_wandb: true
|
||||
save_csv: true
|
||||
|
||||
# misc
|
||||
@@ -87,6 +87,5 @@ episode_lengths: ???
|
||||
seed_steps: ???
|
||||
bin_size: ???
|
||||
|
||||
# compile
|
||||
# speedups
|
||||
compile: False
|
||||
cudagraphs: False
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import functools
|
||||
|
||||
from common import math
|
||||
from common.scale import RunningScale
|
||||
from common.world_model import WorldModel
|
||||
from tensordict.nn import CudaGraphModule
|
||||
from tensordict import TensorDict
|
||||
|
||||
CG_WARMUP = 1000
|
||||
|
||||
|
||||
class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
@@ -21,11 +17,8 @@ class TDMPC2(torch.nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.device = torch.device('cuda:0')
|
||||
|
||||
self.model = WorldModel(cfg).to(self.device)
|
||||
capturable = True
|
||||
self.optim = torch.optim.Adam([
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
@@ -33,8 +26,8 @@ class TDMPC2(torch.nn.Module):
|
||||
{'params': self.model._Qs.parameters()},
|
||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||
}
|
||||
], lr=self.cfg.lr, capturable=capturable)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=capturable)
|
||||
], lr=self.cfg.lr, capturable=True)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr, eps=1e-5, capturable=True)
|
||||
self.model.eval()
|
||||
self.scale = RunningScale(cfg)
|
||||
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
|
||||
@@ -42,43 +35,16 @@ class TDMPC2(torch.nn.Module):
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
|
||||
if cfg.compile:
|
||||
mode = None if cfg.cudagraphs else "reduce-overhead"
|
||||
print('compiling - update')
|
||||
self._update = torch.compile(self._update, mode=mode)
|
||||
|
||||
if cfg.cudagraphs:
|
||||
print('cudagraphs - update')
|
||||
self._update = CudaGraphModule(self._update, warmup=CG_WARMUP)
|
||||
self._update = torch.compile(self._update, mode="reduce-overhead")
|
||||
|
||||
@property
|
||||
def plan(self):
|
||||
_plan_val = getattr(self, "_plan_val", None)
|
||||
if _plan_val is not None:
|
||||
return _plan_val
|
||||
if self.cfg.cudagraphs:
|
||||
print('cudagraphs - plan')
|
||||
self._plan_dict = {
|
||||
(True, True): functools.partial(self._plan, t0=True, eval_mode=True),
|
||||
(False, True): functools.partial(self._plan, t0=False, eval_mode=True),
|
||||
(True, False): functools.partial(self._plan, t0=True, eval_mode=False),
|
||||
(False, False): functools.partial(self._plan, t0=False, eval_mode=False),
|
||||
}
|
||||
if self.cfg.compile:
|
||||
print('compiling - plan')
|
||||
mode = None
|
||||
self._plan_dict = {k: torch.compile(func, mode=mode) for k, func in self._plan_dict.items()}
|
||||
|
||||
self._plan_dict = {k: CudaGraphModule(func, warmup=CG_WARMUP) for k, func in self._plan_dict.items()}
|
||||
def plan(obs, t0=False, eval_mode=False, task=None):
|
||||
if task is not None:
|
||||
kwargs = {"task": task}
|
||||
else:
|
||||
kwargs = {}
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return self._plan_dict[(t0, eval_mode)](obs=obs, **kwargs)
|
||||
elif self.cfg.compile:
|
||||
if self.cfg.compile:
|
||||
plan = torch.compile(self._plan, mode="reduce-overhead")
|
||||
else:
|
||||
plan = self._plan
|
||||
@@ -247,7 +213,6 @@ class TDMPC2(torch.nn.Module):
|
||||
pi_loss.backward()
|
||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||
self.pi_optim.step()
|
||||
# For some reason, cudagraph prefers to see the zero grad after step
|
||||
self.pi_optim.zero_grad(set_to_none=True)
|
||||
|
||||
return pi_loss.detach(), pi_grad_norm
|
||||
@@ -269,23 +234,6 @@ class TDMPC2(torch.nn.Module):
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||
|
||||
def update(self, buffer):
|
||||
"""
|
||||
Main update function. Corresponds to one iteration of model learning.
|
||||
|
||||
Args:
|
||||
buffer (common.buffer.Buffer): Replay buffer.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary of training statistics.
|
||||
"""
|
||||
obs, action, reward, task = buffer.sample()
|
||||
kwargs = {}
|
||||
if task is not None:
|
||||
kwargs["task"] = task
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
return self._update(obs, action, reward, **kwargs)
|
||||
|
||||
def _update(self, obs, action, reward, task=None):
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
@@ -314,7 +262,7 @@ class TDMPC2(torch.nn.Module):
|
||||
reward_loss, value_loss = 0, 0
|
||||
for t, (rew_pred_unbind, rew_unbind, td_targets_unbind, qs_unbind) in enumerate(zip(reward_preds.unbind(0), reward.unbind(0), td_targets.unbind(0), qs.unbind(1))):
|
||||
reward_loss = reward_loss + math.soft_ce(rew_pred_unbind, rew_unbind, self.cfg).mean() * self.cfg.rho**t
|
||||
for q, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
|
||||
for _, qs_unbind_unbind in enumerate(qs_unbind.unbind(0)):
|
||||
value_loss = value_loss + math.soft_ce(qs_unbind_unbind, td_targets_unbind, self.cfg).mean() * self.cfg.rho**t
|
||||
|
||||
consistency_loss = consistency_loss / self.cfg.horizon
|
||||
@@ -350,3 +298,20 @@ class TDMPC2(torch.nn.Module):
|
||||
"pi_grad_norm": pi_grad_norm,
|
||||
"pi_scale": self.scale.value,
|
||||
}).detach().mean()
|
||||
|
||||
def update(self, buffer):
    """
    Main update function. Samples one batch from the replay buffer and runs
    a single iteration of model learning on it via `self._update`.

    Args:
        buffer (common.buffer.Buffer): Replay buffer.

    Returns:
        Training statistics, exactly as returned by `self._update`.
    """
    obs, action, reward, task = buffer.sample()
    # Forward `task` only when the buffer provides one, so that `_update`
    # otherwise falls back to its own default for that argument.
    extra = {} if task is None else {"task": task}
    # Mark the start of a new step for CUDA graphs before invoking `_update`.
    torch.compiler.cudagraph_mark_step_begin()
    return self._update(obs, action, reward, **extra)
|
||||
|
||||
@@ -68,11 +68,11 @@ class OnlineTrainer(Trainer):
|
||||
train_metrics, done, eval_next = {}, True, False
|
||||
while self._step <= self.cfg.steps:
|
||||
# Evaluate agent periodically
|
||||
if self._step > 0 and self._step % self.cfg.eval_freq == 0:
|
||||
if self._step % self.cfg.eval_freq == 0:
|
||||
eval_next = True
|
||||
|
||||
# Reset environment
|
||||
if done or (self._step == self.cfg.seed_steps + 1):
|
||||
if done:
|
||||
if eval_next:
|
||||
eval_metrics = self.eval()
|
||||
eval_metrics.update(self.common_metrics())
|
||||
|
||||
Reference in New Issue
Block a user