diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 5ac92ad..cc37800 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -13,35 +13,32 @@ def log_std(x, low, dif): return low + 0.5 * dif * (torch.tanh(x) + 1) -def _gaussian_residual(eps, log_std): - return -0.5 * eps.pow(2) - log_std - - -def _gaussian_logprob(residual): - log2pi = 1.8378770351409912 - return residual - 0.5 * log2pi - - -def gaussian_logprob(eps, log_std, size=None): +def gaussian_logprob(eps, log_std): """Compute Gaussian log probability.""" - residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True) - if size is None: - size = eps.shape[-1] - return _gaussian_logprob(residual) * size - - -def _squash(pi): - return torch.log(F.relu(1 - pi.pow(2)) + 1e-6) + residual = -0.5 * eps.pow(2) - log_std + log_prob = residual - 0.9189385175704956 + return log_prob.sum(-1, keepdim=True) def squash(mu, pi, log_pi): """Apply squashing function.""" mu = torch.tanh(mu) pi = torch.tanh(pi) - log_pi -= _squash(pi).sum(-1, keepdim=True) + squashed_pi = torch.log(F.relu(1 - pi.pow(2)) + 1e-6) + log_pi = log_pi - squashed_pi.sum(-1, keepdim=True) return mu, pi, log_pi +def int_to_one_hot(x, num_classes): + """ + Converts an integer tensor to a one-hot tensor. + Supports batched inputs. + """ + one_hot = torch.zeros(*x.shape, num_classes, device=x.device) + one_hot.scatter_(-1, x.unsqueeze(-1), 1) + return one_hot + + def symlog(x): """ Symmetric logarithmic function. diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index eb9633d..8222b99 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn from common import layers, math, init +from tensordict import TensorDict from tensordict.nn import TensorDictParams class WorldModel(nn.Module): @@ -131,9 +132,9 @@ class WorldModel(nn.Module): z = self.task_emb(z, task) # Gaussian policy prior - mu, log_std = self._pi(z).chunk(2, dim=-1) + mean, log_std = self._pi(z).chunk(2, dim=-1) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) - eps = torch.randn_like(mu) + eps = torch.randn_like(mean) if self.cfg.multitask: # Mask out unused action dimensions mu = mu * self._action_masks[task] @@ -143,11 +144,23 @@ class WorldModel(nn.Module): else: # No masking action_dims = None - log_pi = math.gaussian_logprob(eps, log_std, size=action_dims) - pi = mu + eps * log_std.exp() - mu, pi, log_pi = math.squash(mu, pi, log_pi) + log_prob = math.gaussian_logprob(eps, log_std) - return mu, pi, log_pi, log_std + # Scale log probability by action dimensions + size = eps.shape[-1] if action_dims is None else action_dims + scaled_log_prob = log_prob * size + + # Reparameterization trick + action = mean + eps * log_std.exp() + mean, action, log_prob = math.squash(mean, action, log_prob) + + info = TensorDict({ + "mean": mean, + "log_std": log_std, + "entropy": -log_prob, + "entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob, + }) + return action, info def Q(self, z, a, task, return_type='min', target=False, detach=False): """ diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index e4d8ec2..12c2832 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -103,11 +103,12 @@ class TDMPC2(torch.nn.Module): if task is not None: task = torch.tensor([task], device=self.device) if self.cfg.mpc: - a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) - else: - z = self.model.encode(obs, task) - a = self.model.pi(z, task)[int(not eval_mode)][0] - return a.cpu() + return self.plan(obs, t0=t0, eval_mode=eval_mode, task=task).cpu() + z = self.model.encode(obs, task) + action, info = self.model.pi(z, task) + if eval_mode: + action = info["mean"] + return action[0].cpu() @torch.no_grad() def _estimate_value(self, z, actions, task): @@ -119,7 +120,8 @@ class TDMPC2(torch.nn.Module): G = G + discount * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update - return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') + action, _ = self.model.pi(z, task) + return G + discount * self.model.Q(z, action, task, return_type='avg') @torch.no_grad() def _plan(self, obs, t0=False, eval_mode=False, task=None): @@ -141,9 +143,9 @@ class TDMPC2(torch.nn.Module): pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) _z = z.repeat(self.cfg.num_pi_trajs, 1) for t in range(self.cfg.horizon-1): - pi_actions[t] = self.model.pi(_z, task)[1] + pi_actions[t], _ = self.model.pi(_z, task) _z = self.model.next(_z, pi_actions[t], task) - pi_actions[-1] = self.model.pi(_z, task)[1] + pi_actions[-1], _ = self.model.pi(_z, task) # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) @@ -202,20 +204,27 @@ class TDMPC2(torch.nn.Module): Returns: float: Loss of the policy update. """ - _, pis, log_pis, _ = self.model.pi(zs, task) - qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) + action, info = self.model.pi(zs, task) + qs = self.model.Q(zs, action, task, return_type='avg', detach=True) self.scale.update(qs[0]) qs = self.scale(qs) # Loss is a weighted sum of Q-values rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() + pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() self.pi_optim.zero_grad(set_to_none=True) - return pi_loss.detach(), pi_grad_norm + info = TensorDict({ + "pi_loss": pi_loss, + "pi_grad_norm": pi_grad_norm, + "pi_entropy": info["entropy"], + "pi_entropy_scale": info["entropy_scale"], + "pi_scale": self.scale.value, + }) + return info @torch.no_grad() def _td_target(self, next_z, reward, task): @@ -230,9 +239,9 @@ class TDMPC2(torch.nn.Module): Returns: torch.Tensor: TD-target. """ - pi = self.model.pi(next_z, task)[1] + action, _ = self.model.pi(next_z, task) discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount - return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) + return reward + discount * self.model.Q(next_z, action, task, return_type='min', target=True) def _update(self, obs, action, reward, task=None): # Compute targets @@ -281,23 +290,22 @@ class TDMPC2(torch.nn.Module): self.optim.zero_grad(set_to_none=True) # Update policy - pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) + pi_info = self.update_pi(zs.detach(), task) # Update target Q-functions self.model.soft_update_target_Q() # Return training statistics self.model.eval() - return TensorDict({ + info = TensorDict({ "consistency_loss": consistency_loss, "reward_loss": reward_loss, "value_loss": value_loss, - "pi_loss": pi_loss, "total_loss": total_loss, "grad_norm": grad_norm, - "pi_grad_norm": pi_grad_norm, - "pi_scale": self.scale.value, - }).detach().mean() + }) + info.update(pi_info) + return info.detach().mean() def update(self, buffer): """