From 10a0be2724a75eb61dcf4049b3c5a9b84bbcc459 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 10 Nov 2024 23:16:32 -0800 Subject: [PATCH] fix indexing --- tdmpc2/tdmpc2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index ff41dbf..597cc28 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -34,7 +34,7 @@ class TDMPC2(torch.nn.Module): self.discount = torch.tensor( [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) - self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) + self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: print('Compiling update function with torch.compile...') self._update = torch.compile(self._update, mode="reduce-overhead") @@ -169,22 +169,23 @@ class TDMPC2(torch.nn.Module): # Compute elite actions value = self._estimate_value(z, actions, task).nan_to_num(0) elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices - elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs] + elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2)) + elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim)) # Update parameters max_value = elite_value.max(1).values score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1))) - score = score / score.sum(1) - mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9) - std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt() + score = (score / score.sum(1, keepdim=True)) + mean = (score.unsqueeze(1) * elite_actions).sum(2) / (score.sum(1, keepdim=True) + 1e-9) + std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(2) / (score.sum(1, keepdim=True) + 1e-9)).sqrt() std = std.clamp(self.cfg.min_std, self.cfg.max_std) if self.cfg.multitask: mean = mean * self.model._action_masks[task] std = std * self.model._action_masks[task] # Select action - rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs - actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) + rand_idx = math.gumbel_softmax_sample(score.squeeze(2), dim=1) # gumbel_softmax_sample is compatible with cuda graphs + actions = elite_actions[torch.arange(self.cfg.num_envs), :, rand_idx] action, std = actions[:, 0], std[:, 0] if not eval_mode: action = action + std * torch.randn(self.cfg.action_dim, device=std.device)