refactor pi outputs
This commit is contained in:
@@ -15,7 +15,11 @@ class Ensemble(nn.Module):
|
||||
self.params = from_modules(*modules, as_module=True)
|
||||
with self.params[0].data.to("meta").to_module(modules[0]):
|
||||
self.module = deepcopy(modules[0])
|
||||
self._repr = str(modules)
|
||||
self._repr = str(modules[0])
|
||||
self._n = len(modules)
|
||||
|
||||
def __len__(self):
|
||||
return self._n
|
||||
|
||||
def _call(self, params, *args, **kwargs):
|
||||
with params.to_module(self.module):
|
||||
@@ -25,7 +29,7 @@ class Ensemble(nn.Module):
|
||||
return torch.vmap(self._call, (0, None), randomness="different")(self.params, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return 'Vectorized ' + self._repr
|
||||
return f'Vectorized {len(self)}x ' + self._repr
|
||||
|
||||
|
||||
class ShiftAug(nn.Module):
|
||||
|
||||
@@ -154,11 +154,13 @@ class WorldModel(nn.Module):
|
||||
action = mean + eps * log_std.exp()
|
||||
mean, action, log_prob = math.squash(mean, action, log_prob)
|
||||
|
||||
entropy_scale = scaled_log_prob / (log_prob + 1e-8)
|
||||
info = TensorDict({
|
||||
"mean": mean,
|
||||
"log_std": log_std,
|
||||
"action_prob": 1.,
|
||||
"entropy": -log_prob,
|
||||
"entropy_scale": self.cfg.entropy_coef * scaled_log_prob / log_prob,
|
||||
"scaled_entropy": -log_prob * entropy_scale,
|
||||
})
|
||||
return action, info
|
||||
|
||||
|
||||
@@ -226,7 +226,7 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
# Loss is a weighted sum of Q-values
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||
pi_loss = (-(info["entropy_scale"] * info["entropy"] + qs).mean(dim=(1,2)) * rho).mean()
|
||||
pi_loss = (-(self.cfg.entropy_coef * info["scaled_entropy"] + qs).mean(dim=(1,2)) * rho).mean()
|
||||
pi_loss.backward()
|
||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||
self.pi_optim.step()
|
||||
@@ -236,7 +236,7 @@ class TDMPC2(torch.nn.Module):
|
||||
"pi_loss": pi_loss,
|
||||
"pi_grad_norm": pi_grad_norm,
|
||||
"pi_entropy": info["entropy"],
|
||||
"pi_entropy_scale": info["entropy_scale"],
|
||||
"pi_scaled_entropy": info["scaled_entropy"],
|
||||
"pi_scale": self.scale.value,
|
||||
})
|
||||
return info
|
||||
|
||||
Reference in New Issue
Block a user