This commit is contained in:
Nicklas Hansen
2024-02-11 14:44:16 -08:00
parent 51d6b8d7a9
commit 9dd3e673c4

View File

@@ -147,18 +147,10 @@ class TDMPC2:
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num_(0) value = self._estimate_value(z, actions, task).nan_to_num_(0)
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2)) elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim)) elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
# vectorized version
# elite_value, elite_actions = [], []
# for i in range(self.cfg.num_envs):
# elite_value.append(value[i, elite_idxs[i]])
# elite_actions.append(actions[i, elite_idxs[i]])
# elite_value = torch.stack(elite_value, dim=0)
# Update parameters # Update parameters
max_value = elite_value.max(1)[0] max_value = elite_value.max(1)[0]
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1))) score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))