diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index d0a54c4..cca51e1 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -147,17 +147,9 @@ class TDMPC2: # Compute elite actions value = self._estimate_value(z, actions, task).nan_to_num_(0) - elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2)) elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim)) - - # vectorized version - # elite_value, elite_actions = [], [] - # for i in range(self.cfg.num_envs): - # elite_value.append(value[i, elite_idxs[i]]) - # elite_actions.append(actions[i, elite_idxs[i]]) - # elite_value = torch.stack(elite_value, dim=0) # Update parameters max_value = elite_value.max(1)[0]