diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index a6bf1be..cfc3ed7 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -175,6 +175,7 @@ class WorldModel(nn.Module): if z.dim() == 2: # z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim) z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1) + actions = actions.repeat(z.shape[0], 1, 1) elif z.dim() == 3: # z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim) z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 6718a82..ab0f41b 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -30,7 +30,7 @@ exp_name: default data_dir: ??? # planning -mpc: false +mpc: true iterations: 6 num_samples: 512 num_elites: 64 diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 82f4313..6e2e780 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -121,7 +121,10 @@ class TDMPC2(torch.nn.Module): G = G + discount * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update - return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') + pi = self.model.pi(z, task)[1] + if self.cfg.action == 'discrete': + pi = pi.squeeze(1) # TODO: this is a bit hacky + return G + discount * self.model.Q(z, pi, task, return_type='avg') @torch.no_grad() def _plan(self, obs, t0=False, eval_mode=False, task=None): @@ -143,20 +146,46 @@ class TDMPC2(torch.nn.Module): pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) _z = z.repeat(self.cfg.num_pi_trajs, 1) for t in range(self.cfg.horizon-1): - pi_actions[t] = self.model.pi(_z, task)[1] + action = self.model.pi(_z, task)[1] + if self.cfg.action == 'discrete': + action = action.squeeze(1) + pi_actions[t] = action _z = self.model.next(_z, pi_actions[t], task) - pi_actions[-1] = self.model.pi(_z, task)[1] + action = self.model.pi(_z, task)[1] + if self.cfg.action == 'discrete': + action = action.squeeze(1) + pi_actions[-1] = action # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) - mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) - std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) - if not t0: - mean[:-1] = self._prev_mean[1:] + if self.cfg.action == 'continuous': + mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) + std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) + if not t0: + mean[:-1] = self._prev_mean[1:] actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: actions[:, :self.cfg.num_pi_trajs] = pi_actions + # Random shooting + if self.cfg.action == 'discrete': + # Sample actions + actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device) + actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim) + + # Compute elite actions + value = self._estimate_value(z, actions, task).nan_to_num(0) + elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices + elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs] + + # Sample action according to score + max_value = elite_value.max(0).values + score = torch.exp(self.cfg.temperature*(elite_value - max_value)) + score = score / score.sum(0) + rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs + actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) + return actions[0] + # Iterate MPPI for _ in range(self.cfg.iterations): @@ -191,6 +220,7 @@ class TDMPC2(torch.nn.Module): if not eval_mode: a = a + std * torch.randn(self.cfg.action_dim, device=std.device) self._prev_mean.copy_(mean) + return a.clamp(-1, 1) def update_pi(self, zs, task):