refactor policy
This commit is contained in:
@@ -13,35 +13,32 @@ def log_std(x, low, dif):
|
|||||||
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||||
|
|
||||||
|
|
||||||
def _gaussian_residual(eps, log_std):
|
def gaussian_logprob(eps, log_std):
|
||||||
return -0.5 * eps.pow(2) - log_std
|
|
||||||
|
|
||||||
|
|
||||||
def _gaussian_logprob(residual):
|
|
||||||
log2pi = 1.8378770351409912
|
|
||||||
return residual - 0.5 * log2pi
|
|
||||||
|
|
||||||
|
|
||||||
def gaussian_logprob(eps, log_std, size=None):
|
|
||||||
"""Compute Gaussian log probability."""
|
"""Compute Gaussian log probability."""
|
||||||
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
residual = -0.5 * eps.pow(2) - log_std
|
||||||
if size is None:
|
log_prob = residual - 0.9189385175704956
|
||||||
size = eps.shape[-1]
|
return log_prob.sum(-1, keepdim=True)
|
||||||
return _gaussian_logprob(residual) * size
|
|
||||||
|
|
||||||
|
|
||||||
def _squash(pi):
    """Return the elementwise log-correction term for a tanh-squashed sample.

    Computes ``log(relu(1 - pi**2) + 1e-6)`` for an already-squashed ``pi``.
    The ``relu`` clamp and the ``1e-6`` offset keep the logarithm finite
    when ``pi**2`` rounds to 1 (or marginally above it).
    """
    return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)


def squash(mu, pi, log_pi):
    """Apply squashing function.

    Args:
        mu: pre-squash mean tensor.
        pi: pre-squash sampled action tensor.
        log_pi: log-probability of ``pi``; it must broadcast against a
            tensor whose last dimension has been reduced with
            ``keepdim=True``.

    Returns:
        tuple: ``(tanh(mu), tanh(pi), corrected_log_pi)`` where the
        correction subtracts ``_squash(tanh(pi))`` summed over the last
        dimension.
    """
    mu = torch.tanh(mu)
    pi = torch.tanh(pi)
    # Out-of-place subtraction: an in-place `log_pi -= ...` would silently
    # modify the tensor object owned by the caller.
    log_pi = log_pi - _squash(pi).sum(-1, keepdim=True)
    return mu, pi, log_pi
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_one_hot(x, num_classes):
    """
    Converts an integer tensor to a one-hot tensor.
    Supports batched inputs.

    Args:
        x (torch.Tensor): integer class indices of any shape.
        num_classes (int): size of the trailing one-hot dimension.

    Returns:
        torch.Tensor: tensor of shape ``(*x.shape, num_classes)`` on
        ``x.device`` (default dtype), holding a 1 at each index given by
        ``x`` along the last dimension and 0 elsewhere.
    """
    # scatter_ only accepts int64 indices, so widen narrower integer dtypes
    # (int32, int16, uint8, ...). Floating-point input is left untouched so
    # that scatter_ still rejects it instead of silently truncating values.
    if not (x.is_floating_point() or x.is_complex()):
        x = x.long()
    one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return one_hot
|
||||||
|
|
||||||
|
|
||||||
def symlog(x):
|
def symlog(x):
|
||||||
"""
|
"""
|
||||||
Symmetric logarithmic function.
|
Symmetric logarithmic function.
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
|
from tensordict import TensorDict
|
||||||
from tensordict.nn import TensorDictParams
|
from tensordict.nn import TensorDictParams
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
@@ -131,9 +132,9 @@ class WorldModel(nn.Module):
|
|||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
|
|
||||||
# Gaussian policy prior
|
# Gaussian policy prior
|
||||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
mean, log_std = self._pi(z).chunk(2, dim=-1)
|
||||||
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
||||||
eps = torch.randn_like(mu)
|
eps = torch.randn_like(mean)
|
||||||
|
|
||||||
if self.cfg.multitask: # Mask out unused action dimensions
|
if self.cfg.multitask: # Mask out unused action dimensions
|
||||||
mu = mu * self._action_masks[task]
|
mu = mu * self._action_masks[task]
|
||||||
@@ -143,11 +144,23 @@ class WorldModel(nn.Module):
|
|||||||
else: # No masking
|
else: # No masking
|
||||||
action_dims = None
|
action_dims = None
|
||||||
|
|
||||||
log_pi = math.gaussian_logprob(eps, log_std, size=action_dims)
|
log_prob = math.gaussian_logprob(eps, log_std)
|
||||||
pi = mu + eps * log_std.exp()
|
|
||||||
mu, pi, log_pi = math.squash(mu, pi, log_pi)
|
|
||||||
|
|
||||||
return mu, pi, log_pi, log_std
|
# Scale log probability by action dimensions
|
||||||
|
size = eps.shape[-1] if action_dims is None else action_dims
|
||||||
|
scaled_log_prob = log_prob * size
|
||||||
|
|
||||||
|
# Reparameterization trick
|
||||||
|
action = mean + eps * log_std.exp()
|
||||||
|
mean, action, log_prob = math.squash(mean, action, log_prob)
|
||||||
|
|
||||||
|
info = TensorDict({
|
||||||
|
"mean": mean,
|
||||||
|
"log_std": log_std,
|
||||||
|
"entropy": -log_prob,
|
||||||
|
"entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob,
|
||||||
|
})
|
||||||
|
return action, info
|
||||||
|
|
||||||
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -103,11 +103,12 @@ class TDMPC2(torch.nn.Module):
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
task = torch.tensor([task], device=self.device)
|
task = torch.tensor([task], device=self.device)
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
return self.plan(obs, t0=t0, eval_mode=eval_mode, task=task).cpu()
|
||||||
else:
|
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
action, info = self.model.pi(z, task)
|
||||||
return a.cpu()
|
if eval_mode:
|
||||||
|
action = info["mean"]
|
||||||
|
return action[0].cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
@@ -119,7 +120,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
G = G + discount * reward
|
G = G + discount * reward
|
||||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
discount = discount * discount_update
|
discount = discount * discount_update
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
action, _ = self.model.pi(z, task)
|
||||||
|
return G + discount * self.model.Q(z, action, task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
@@ -141,9 +143,9 @@ class TDMPC2(torch.nn.Module):
|
|||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||||
for t in range(self.cfg.horizon-1):
|
for t in range(self.cfg.horizon-1):
|
||||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
pi_actions[t], _ = self.model.pi(_z, task)
|
||||||
_z = self.model.next(_z, pi_actions[t], task)
|
_z = self.model.next(_z, pi_actions[t], task)
|
||||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
pi_actions[-1], _ = self.model.pi(_z, task)
|
||||||
|
|
||||||
# Initialize state and parameters
|
# Initialize state and parameters
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.repeat(self.cfg.num_samples, 1)
|
||||||
@@ -202,20 +204,27 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
float: Loss of the policy update.
|
float: Loss of the policy update.
|
||||||
"""
|
"""
|
||||||
_, pis, log_pis, _ = self.model.pi(zs, task)
|
action, info = self.model.pi(zs, task)
|
||||||
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
|
qs = self.model.Q(zs, action, task, return_type='avg', detach=True)
|
||||||
self.scale.update(qs[0])
|
self.scale.update(qs[0])
|
||||||
qs = self.scale(qs)
|
qs = self.scale(qs)
|
||||||
|
|
||||||
# Loss is a weighted sum of Q-values
|
# Loss is a weighted sum of Q-values
|
||||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||||
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean()
|
||||||
pi_loss.backward()
|
pi_loss.backward()
|
||||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||||
self.pi_optim.step()
|
self.pi_optim.step()
|
||||||
self.pi_optim.zero_grad(set_to_none=True)
|
self.pi_optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
return pi_loss.detach(), pi_grad_norm
|
info = TensorDict({
|
||||||
|
"pi_loss": pi_loss,
|
||||||
|
"pi_grad_norm": pi_grad_norm,
|
||||||
|
"pi_entropy": info["entropy"],
|
||||||
|
"pi_entropy_scale": info["entropy_scale"],
|
||||||
|
"pi_scale": self.scale.value,
|
||||||
|
})
|
||||||
|
return info
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _td_target(self, next_z, reward, task):
|
def _td_target(self, next_z, reward, task):
|
||||||
@@ -230,9 +239,9 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: TD-target.
|
torch.Tensor: TD-target.
|
||||||
"""
|
"""
|
||||||
pi = self.model.pi(next_z, task)[1]
|
action, _ = self.model.pi(next_z, task)
|
||||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
return reward + discount * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||||
|
|
||||||
def _update(self, obs, action, reward, task=None):
|
def _update(self, obs, action, reward, task=None):
|
||||||
# Compute targets
|
# Compute targets
|
||||||
@@ -281,23 +290,22 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Update policy
|
# Update policy
|
||||||
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
|
pi_info = self.update_pi(zs.detach(), task)
|
||||||
|
|
||||||
# Update target Q-functions
|
# Update target Q-functions
|
||||||
self.model.soft_update_target_Q()
|
self.model.soft_update_target_Q()
|
||||||
|
|
||||||
# Return training statistics
|
# Return training statistics
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
return TensorDict({
|
info = TensorDict({
|
||||||
"consistency_loss": consistency_loss,
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": reward_loss,
|
"reward_loss": reward_loss,
|
||||||
"value_loss": value_loss,
|
"value_loss": value_loss,
|
||||||
"pi_loss": pi_loss,
|
|
||||||
"total_loss": total_loss,
|
"total_loss": total_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
"pi_grad_norm": pi_grad_norm,
|
})
|
||||||
"pi_scale": self.scale.value,
|
info.update(pi_info)
|
||||||
}).detach().mean()
|
return info.detach().mean()
|
||||||
|
|
||||||
def update(self, buffer):
|
def update(self, buffer):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user