refactor policy

This commit is contained in:
Nicklas Hansen
2024-12-03 12:22:02 -08:00
parent 0a79c8bd38
commit 32fc2bdf93
3 changed files with 63 additions and 45 deletions

View File

@@ -13,35 +13,32 @@ def log_std(x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1) return low + 0.5 * dif * (torch.tanh(x) + 1)
def _gaussian_residual(eps, log_std): def gaussian_logprob(eps, log_std):
return -0.5 * eps.pow(2) - log_std
def _gaussian_logprob(residual):
log2pi = 1.8378770351409912
return residual - 0.5 * log2pi
def gaussian_logprob(eps, log_std, size=None):
"""Compute Gaussian log probability.""" """Compute Gaussian log probability."""
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) residual = -0.5 * eps.pow(2) - log_std
if size is None: log_prob = residual - 0.9189385175704956
size = eps.shape[-1] return log_prob.sum(-1, keepdim=True)
return _gaussian_logprob(residual) * size
def _squash(pi):
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
def squash(mu, pi, log_pi): def squash(mu, pi, log_pi):
"""Apply squashing function.""" """Apply squashing function."""
mu = torch.tanh(mu) mu = torch.tanh(mu)
pi = torch.tanh(pi) pi = torch.tanh(pi)
log_pi -= _squash(pi).sum(-1, keepdim=True) squashed_pi = torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
log_pi = log_pi - squashed_pi.sum(-1, keepdim=True)
return mu, pi, log_pi return mu, pi, log_pi
def int_to_one_hot(x, num_classes):
    """
    Converts an integer tensor to a one-hot tensor.
    Supports batched inputs.

    Args:
        x: Tensor of class indices of any shape (including 0-dim). Any
            integer dtype is accepted.
        num_classes: Size of the one-hot dimension appended to ``x.shape``.

    Returns:
        Tensor of shape ``(*x.shape, num_classes)`` on ``x.device`` in the
        default floating-point dtype, with a 1 at each index given by ``x``
        and 0 elsewhere.
    """
    index = x.unsqueeze(-1)
    if not index.is_floating_point():
        # scatter_ only accepts int64 indices; widen narrower integer dtypes
        # (int32, int16, uint8, ...) instead of failing on them. This is a
        # no-op for tensors that are already int64.
        index = index.long()
    one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    one_hot.scatter_(-1, index, 1)
    return one_hot
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.

View File

@@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from common import layers, math, init from common import layers, math, init
from tensordict import TensorDict
from tensordict.nn import TensorDictParams from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
@@ -131,9 +132,9 @@ class WorldModel(nn.Module):
z = self.task_emb(z, task) z = self.task_emb(z, task)
# Gaussian policy prior # Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1) mean, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
eps = torch.randn_like(mu) eps = torch.randn_like(mean)
if self.cfg.multitask: # Mask out unused action dimensions if self.cfg.multitask: # Mask out unused action dimensions
mu = mu * self._action_masks[task] mu = mu * self._action_masks[task]
@@ -143,11 +144,23 @@ class WorldModel(nn.Module):
else: # No masking else: # No masking
action_dims = None action_dims = None
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims) log_prob = math.gaussian_logprob(eps, log_std)
pi = mu + eps * log_std.exp()
mu, pi, log_pi = math.squash(mu, pi, log_pi)
return mu, pi, log_pi, log_std # Scale log probability by action dimensions
size = eps.shape[-1] if action_dims is None else action_dims
scaled_log_prob = log_prob * size
# Reparameterization trick
action = mean + eps * log_std.exp()
mean, action, log_prob = math.squash(mean, action, log_prob)
info = TensorDict({
"mean": mean,
"log_std": log_std,
"entropy": -log_prob,
"entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob,
})
return action, info
def Q(self, z, a, task, return_type='min', target=False, detach=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """

View File

@@ -103,11 +103,12 @@ class TDMPC2(torch.nn.Module):
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) return self.plan(obs, t0=t0, eval_mode=eval_mode, task=task).cpu()
else:
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
a = self.model.pi(z, task)[int(not eval_mode)][0] action, info = self.model.pi(z, task)
return a.cpu() if eval_mode:
action = info["mean"]
return action[0].cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
@@ -119,7 +120,8 @@ class TDMPC2(torch.nn.Module):
G = G + discount * reward G = G + discount * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') action, _ = self.model.pi(z, task)
return G + discount * self.model.Q(z, action, task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
@@ -141,9 +143,9 @@ class TDMPC2(torch.nn.Module):
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1): for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1] pi_actions[t], _ = self.model.pi(_z, task)
_z = self.model.next(_z, pi_actions[t], task) _z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1] pi_actions[-1], _ = self.model.pi(_z, task)
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.repeat(self.cfg.num_samples, 1)
@@ -202,20 +204,27 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
_, pis, log_pis, _ = self.model.pi(zs, task) action, info = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) qs = self.model.Q(zs, action, task, return_type='avg', detach=True)
self.scale.update(qs[0]) self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
# Loss is a weighted sum of Q-values # Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()
self.pi_optim.zero_grad(set_to_none=True) self.pi_optim.zero_grad(set_to_none=True)
return pi_loss.detach(), pi_grad_norm info = TensorDict({
"pi_loss": pi_loss,
"pi_grad_norm": pi_grad_norm,
"pi_entropy": info["entropy"],
"pi_entropy_scale": info["entropy_scale"],
"pi_scale": self.scale.value,
})
return info
@torch.no_grad() @torch.no_grad()
def _td_target(self, next_z, reward, task): def _td_target(self, next_z, reward, task):
@@ -230,9 +239,9 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: TD-target. torch.Tensor: TD-target.
""" """
pi = self.model.pi(next_z, task)[1] action, _ = self.model.pi(next_z, task)
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, task=None): def _update(self, obs, action, reward, task=None):
# Compute targets # Compute targets
@@ -281,23 +290,22 @@ class TDMPC2(torch.nn.Module):
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
# Update policy # Update policy
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) pi_info = self.update_pi(zs.detach(), task)
# Update target Q-functions # Update target Q-functions
self.model.soft_update_target_Q() self.model.soft_update_target_Q()
# Return training statistics # Return training statistics
self.model.eval() self.model.eval()
return TensorDict({ info = TensorDict({
"consistency_loss": consistency_loss, "consistency_loss": consistency_loss,
"reward_loss": reward_loss, "reward_loss": reward_loss,
"value_loss": value_loss, "value_loss": value_loss,
"pi_loss": pi_loss,
"total_loss": total_loss, "total_loss": total_loss,
"grad_norm": grad_norm, "grad_norm": grad_norm,
"pi_grad_norm": pi_grad_norm, })
"pi_scale": self.scale.value, info.update(pi_info)
}).detach().mean() return info.detach().mean()
def update(self, buffer): def update(self, buffer):
""" """