Fix per-environment elite indexing in batched planning (use torch.gather; add num_envs dimension to _prev_mean)

This commit is contained in:
Nicklas Hansen
2024-11-10 23:16:32 -08:00
parent ad2342e258
commit 10a0be2724

View File

@@ -34,7 +34,7 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
print('Compiling update function with torch.compile...') print('Compiling update function with torch.compile...')
self._update = torch.compile(self._update, mode="reduce-overhead") self._update = torch.compile(self._update, mode="reduce-overhead")
@@ -169,22 +169,23 @@ class TDMPC2(torch.nn.Module):
# Compute elite actions # Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0) value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs] elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
# Update parameters # Update parameters
max_value = elite_value.max(1).values max_value = elite_value.max(1).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1))) score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
score = score / score.sum(1) score = (score / score.sum(1, keepdim=True))
mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9) mean = (score.unsqueeze(1) * elite_actions).sum(2) / (score.sum(1, keepdim=True) + 1e-9)
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt() std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(2) / (score.sum(1, keepdim=True) + 1e-9)).sqrt()
std = std.clamp(self.cfg.min_std, self.cfg.max_std) std = std.clamp(self.cfg.min_std, self.cfg.max_std)
if self.cfg.multitask: if self.cfg.multitask:
mean = mean * self.model._action_masks[task] mean = mean * self.model._action_masks[task]
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs rand_idx = math.gumbel_softmax_sample(score.squeeze(2), dim=1) # gumbel_softmax_sample is compatible with cuda graphs
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) actions = elite_actions[torch.arange(self.cfg.num_envs), :, rand_idx]
action, std = actions[:, 0], std[:, 0] action, std = actions[:, 0], std[:, 0]
if not eval_mode: if not eval_mode:
action = action + std * torch.randn(self.cfg.action_dim, device=std.device) action = action + std * torch.randn(self.cfg.action_dim, device=std.device)