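Fix indexing for batched (num_envs) planning: per-environment elite gather, score normalization, and action selection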
This commit is contained in:
@@ -34,7 +34,7 @@ class TDMPC2(torch.nn.Module):
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
print('Compiling update function with torch.compile...')
|
||||
self._update = torch.compile(self._update, mode="reduce-overhead")
|
||||
@@ -169,22 +169,23 @@ class TDMPC2(torch.nn.Module):
|
||||
# Compute elite actions
|
||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs]
|
||||
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(1).values
|
||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
||||
score = score / score.sum(1)
|
||||
mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9)
|
||||
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt()
|
||||
score = (score / score.sum(1, keepdim=True))
|
||||
mean = (score.unsqueeze(1) * elite_actions).sum(2) / (score.sum(1, keepdim=True) + 1e-9)
|
||||
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(2) / (score.sum(1, keepdim=True) + 1e-9)).sqrt()
|
||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||
if self.cfg.multitask:
|
||||
mean = mean * self.model._action_masks[task]
|
||||
std = std * self.model._action_masks[task]
|
||||
|
||||
# Select action
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(2), dim=1) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
actions = elite_actions[torch.arange(self.cfg.num_envs), :, rand_idx]
|
||||
action, std = actions[:, 0], std[:, 0]
|
||||
if not eval_mode:
|
||||
action = action + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||
|
||||
Reference in New Issue
Block a user