Clean up: remove the commented-out per-environment loop for elite selection
This commit is contained in:
@@ -147,18 +147,10 @@ class TDMPC2:
|
|||||||
|
|
||||||
# Compute elite actions
|
# Compute elite actions
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
value = self._estimate_value(z, actions, task).nan_to_num_(0)
|
||||||
|
|
||||||
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||||
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||||
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
||||||
|
|
||||||
# vectorized version
|
|
||||||
# elite_value, elite_actions = [], []
|
|
||||||
# for i in range(self.cfg.num_envs):
|
|
||||||
# elite_value.append(value[i, elite_idxs[i]])
|
|
||||||
# elite_actions.append(actions[i, elite_idxs[i]])
|
|
||||||
# elite_value = torch.stack(elite_value, dim=0)
|
|
||||||
|
|
||||||
# Update parameters
|
# Update parameters
|
||||||
max_value = elite_value.max(1)[0]
|
max_value = elite_value.max(1)[0]
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
||||||
|
|||||||
Reference in New Issue
Block a user