diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 0e0b68c..3b06499 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -106,9 +106,8 @@ class TDMPC2(torch.nn.Module): action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) else: z = self.model.encode(obs, task) - action = self.model.pi(z, task)[int(not eval_mode)][0] - if self.cfg.action_space == 'discrete': - action = action.squeeze(0) # TODO: this is a bit hacky + select_idx = int(not eval_mode or self.cfg.action_space == 'discrete') + action = self.model.pi(z, task)[select_idx][0] return action.cpu() @torch.no_grad() @@ -126,6 +125,33 @@ class TDMPC2(torch.nn.Module): pi = pi.squeeze(1) # TODO: this is a bit hacky return G + discount * self.model.Q(z, pi, task, return_type='avg') + @torch.no_grad() + def _sample_policy(self, z, task): + """Sample trajectories from the policy prior.""" + pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) + for t in range(self.cfg.horizon-1): + action = self.model.pi(z, task)[1] + if self.cfg.action_space == 'discrete': + action = action.squeeze(1) + pi_actions[t] = action + z = self.model.next(z, pi_actions[t], task) + action = self.model.pi(z, task)[1] + if self.cfg.action_space == 'discrete': + action = action.squeeze(1) + pi_actions[-1] = action + return pi_actions + + @torch.no_grad() + def _sample_actions(self, n, mean=None, std=None): + """Sample actions from a Gaussian or Categorical distribution.""" + if self.cfg.action_space == 'discrete': + actions = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, n), device=self.device) + actions = math.int_to_one_hot(actions, self.cfg.action_dim) + else: + r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) + actions = (mean.unsqueeze(1) + std.unsqueeze(1) * r).clamp(-1, 1) + return actions + @torch.no_grad() def _plan(self, obs, t0=False, eval_mode=False, task=None): """ @@ -140,88 +166,61 @@ class TDMPC2(torch.nn.Module): Returns: torch.Tensor: Action to take in the environment. """ - # Sample policy trajectories + # Encode observation z = self.model.encode(obs, task) - if self.cfg.num_pi_trajs > 0: - pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) - _z = z.repeat(self.cfg.num_pi_trajs, 1) - for t in range(self.cfg.horizon-1): - action = self.model.pi(_z, task)[1] - if self.cfg.action_space == 'discrete': - action = action.squeeze(1) - pi_actions[t] = action - _z = self.model.next(_z, pi_actions[t], task) - action = self.model.pi(_z, task)[1] - if self.cfg.action_space == 'discrete': - action = action.squeeze(1) - pi_actions[-1] = action - - # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) + + # Initialize parameters if self.cfg.action_space == 'continuous': mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) if not t0: mean[:-1] = self._prev_mean[1:] + else: + mean, std = None, None actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) + + # Sample policy trajectories if self.cfg.num_pi_trajs > 0: - actions[:, :self.cfg.num_pi_trajs] = pi_actions - - # Random shooting - if self.cfg.action_space == 'discrete': - # Sample actions - actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device) - actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim) - - # Compute elite actions - value = self._estimate_value(z, actions, task).nan_to_num(0) - elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices - elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] - - # Sample action according to score - max_value = elite_value.max(0).values - score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score = score / score.sum(0) - rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs - actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) - return actions[0] + actions[:, :self.cfg.num_pi_trajs] = self._sample_policy(z[:self.cfg.num_pi_trajs], task) # Iterate MPPI for _ in range(self.cfg.iterations): - # Sample actions - r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device) - actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r - actions_sample = actions_sample.clamp(-1, 1) - actions[:, self.cfg.num_pi_trajs:] = actions_sample + # Sample random actions + actions[:, self.cfg.num_pi_trajs:] = self._sample_actions(self.cfg.num_samples-self.cfg.num_pi_trajs, mean, std) if self.cfg.multitask: actions = actions * self.model._action_masks[task] - - # Compute elite actions + + # Select elites and compute scores value = self._estimate_value(z, actions, task).nan_to_num(0) elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] - - # Update parameters max_value = elite_value.max(0).values score = torch.exp(self.cfg.temperature*(elite_value - max_value)) score = score / score.sum(0) - mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) - std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() - std = std.clamp(self.cfg.min_std, self.cfg.max_std) - if self.cfg.multitask: - mean = mean * self.model._action_masks[task] - std = std * self.model._action_masks[task] + # Update parameters + if self.cfg.action_space == 'continuous': + mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9) + std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt() + std = std.clamp(self.cfg.min_std, self.cfg.max_std) + if self.cfg.multitask: + mean = mean * self.model._action_masks[task] + std = std * self.model._action_masks[task] + else: + break + # Select action - rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs - actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) - a, std = actions[0], std[0] - if not eval_mode: - a = a + std * torch.randn(self.cfg.action_dim, device=std.device) - self._prev_mean.copy_(mean) - - return a.clamp(-1, 1) + rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs + action = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)[0] + if self.cfg.action_space == 'continuous': + if not eval_mode: + action = action + std[0] * torch.randn(self.cfg.action_dim, device=std.device) + self._prev_mean.copy_(mean) + action = action.clamp(-1, 1) + + return action def update_pi(self, zs, task): """