fix indexing
This commit is contained in:
@@ -34,7 +34,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.num_envs, self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
print('Compiling update function with torch.compile...')
|
print('Compiling update function with torch.compile...')
|
||||||
self._update = torch.compile(self._update, mode="reduce-overhead")
|
self._update = torch.compile(self._update, mode="reduce-overhead")
|
||||||
@@ -169,22 +169,23 @@ class TDMPC2(torch.nn.Module):
|
|||||||
# Compute elite actions
|
# Compute elite actions
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||||
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
elite_idxs = torch.topk(value.squeeze(2), self.cfg.num_elites, dim=1).indices
|
||||||
elite_value, elite_actions = value[elite_idxs], actions[:, :, elite_idxs]
|
elite_value = torch.gather(value, 1, elite_idxs.unsqueeze(2))
|
||||||
|
elite_actions = torch.gather(actions, 2, elite_idxs.unsqueeze(1).unsqueeze(3).expand(-1, self.cfg.horizon, -1, self.cfg.action_dim))
|
||||||
|
|
||||||
# Update parameters
|
# Update parameters
|
||||||
max_value = elite_value.max(1).values
|
max_value = elite_value.max(1).values
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value.unsqueeze(1)))
|
||||||
score = score / score.sum(1)
|
score = (score / score.sum(1, keepdim=True))
|
||||||
mean = (score.unsqueeze(1) * elite_actions).sum(dim=2) / (score.sum(1) + 1e-9)
|
mean = (score.unsqueeze(1) * elite_actions).sum(2) / (score.sum(1, keepdim=True) + 1e-9)
|
||||||
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(dim=2) / (score.sum(1) + 1e-9)).sqrt()
|
std = ((score.unsqueeze(1) * (elite_actions - mean.unsqueeze(2)) ** 2).sum(2) / (score.sum(1, keepdim=True) + 1e-9)).sqrt()
|
||||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
mean = mean * self.model._action_masks[task]
|
mean = mean * self.model._action_masks[task]
|
||||||
std = std * self.model._action_masks[task]
|
std = std * self.model._action_masks[task]
|
||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(2), dim=1) # gumbel_softmax_sample is compatible with cuda graphs
|
||||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
actions = elite_actions[torch.arange(self.cfg.num_envs), :, rand_idx]
|
||||||
action, std = actions[:, 0], std[:, 0]
|
action, std = actions[:, 0], std[:, 0]
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
action = action + std * torch.randn(self.cfg.action_dim, device=std.device)
|
action = action + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user