From 804f9b3949ba0903fbed5014673d3212fdc28aa3 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 24 Dec 2024 03:05:00 -0800 Subject: [PATCH] refactor pi outputs --- tdmpc2/common/layers.py | 8 ++++++-- tdmpc2/common/world_model.py | 4 +++- tdmpc2/tdmpc2.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index 5890d8d..dd34463 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -15,7 +15,11 @@ class Ensemble(nn.Module): self.params = from_modules(*modules, as_module=True) with self.params[0].data.to("meta").to_module(modules[0]): self.module = deepcopy(modules[0]) - self._repr = str(modules) + self._repr = str(modules[0]) + self._n = len(modules) + + def __len__(self): + return self._n def _call(self, params, *args, **kwargs): with params.to_module(self.module): @@ -25,7 +29,7 @@ class Ensemble(nn.Module): return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs) def __repr__(self): - return 'Vectorized ' + self._repr + return f'Vectorized {len(self)}x ' + self._repr class ShiftAug(nn.Module): diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 23b5d40..4babde7 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -154,11 +154,13 @@ class WorldModel(nn.Module): action = mean + eps * log_std.exp() mean, action, log_prob = math.squash(mean, action, log_prob) + entropy_scale = scaled_log_prob / (log_prob + 1e-8) info = TensorDict({ "mean": mean, "log_std": log_std, + "action_prob": 1., "entropy": -log_prob, - "entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob, + "scaled_entropy": -log_prob * entropy_scale, }) return action, info diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index e4f56bf..ef971c5 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -226,7 +226,7 @@ class TDMPC2(torch.nn.Module): # Loss is a weighted sum of Q-values rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean() + pi_loss = (-(self.cfg.entropy_coef * info["scaled_entropy"] + qs).mean(dim=(1,2)) * rho).mean() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() @@ -236,7 +236,7 @@ class TDMPC2(torch.nn.Module): "pi_loss": pi_loss, "pi_grad_norm": pi_grad_norm, "pi_entropy": info["entropy"], - "pi_entropy_scale": info["entropy_scale"], + "pi_scaled_entropy": info["scaled_entropy"], "pi_scale": self.scale.value, }) return info