clean up
This commit is contained in:
@@ -147,17 +147,9 @@ class TDMPC2:
|
||||
|
||||
# Compute elite actions
|
||||
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
||||
|
||||
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
||||
|
||||
# vectorized version
|
||||
# elite_value, elite_actions = [], []
|
||||
# for i in range(self.cfg.num_envs):
|
||||
# elite_value.append(value[i, elite_idxs[i]])
|
||||
# elite_actions.append(actions[i, elite_idxs[i]])
|
||||
# elite_value = torch.stack(elite_value, dim=0)
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(1)[0]
|
||||
|
||||
Reference in New Issue
Block a user