refactor pi outputs

This commit is contained in:
Nicklas Hansen
2024-12-24 03:05:00 -08:00
parent 66f8c21f58
commit 804f9b3949
3 changed files with 11 additions and 5 deletions

View File

@@ -15,7 +15,11 @@ class Ensemble(nn.Module):
self.params = from_modules(*modules, as_module=True)
with self.params[0].data.to("meta").to_module(modules[0]):
self.module = deepcopy(modules[0])
self._repr = str(modules)
self._repr = str(modules[0])
self._n = len(modules)
def __len__(self):
    """Return the member count cached on the instance as ``self._n``."""
    member_count = self._n
    return member_count
def _call(self, params, *args, **kwargs):
with params.to_module(self.module):
@@ -25,7 +29,7 @@ class Ensemble(nn.Module):
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
# Build the display string for this object from the cached `self._repr` text.
# NOTE(review): the two consecutive `return` statements below look like the
# removed and added sides of a diff hunk rendered without +/- markers --
# confirm which one is current. Read literally, only the first `return`
# executes and the second is unreachable.
def __repr__(self):
return 'Vectorized ' + self._repr
# Variant that additionally reports the size obtained from len(self).
return f'Vectorized {len(self)}x ' + self._repr
class ShiftAug(nn.Module):

View File

@@ -154,11 +154,13 @@ class WorldModel(nn.Module):
action = mean + eps * log_std.exp()
mean, action, log_prob = math.squash(mean, action, log_prob)
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
info = TensorDict({
"mean": mean,
"log_std": log_std,
"action_prob": 1.,
"entropy": -log_prob,
"entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob,
"scaled_entropy": -log_prob * entropy_scale,
})
return action, info

View File

@@ -226,7 +226,7 @@ class TDMPC2(torch.nn.Module):
# Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean()
pi_loss = (-(self.cfg.entropy_coef * info["scaled_entropy"] + qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step()
@@ -236,7 +236,7 @@ class TDMPC2(torch.nn.Module):
"pi_loss": pi_loss,
"pi_grad_norm": pi_grad_norm,
"pi_entropy": info["entropy"],
"pi_entropy_scale": info["entropy_scale"],
"pi_scaled_entropy": info["scaled_entropy"],
"pi_scale": self.scale.value,
})
return info