From 8280b82d5cac7c7e534fbd89f0cf8c0c74931f5f Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Mon, 11 Nov 2024 22:36:40 -0800 Subject: [PATCH] argmax policy works --- tdmpc2/common/world_model.py | 34 ++++++++++++++++++++++++-------- tdmpc2/tdmpc2.py | 9 ++++++++- tdmpc2/trainer/online_trainer.py | 5 +++++ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 3482438..a6bf1be 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -156,16 +156,34 @@ class WorldModel(nn.Module): with logits predicted by a neural network. """ # Categorical policy prior - logits = self._pi(z) - policy_dist = Categorical(logits=logits) - action = policy_dist.sample() + # logits = self._pi(z) + # policy_dist = Categorical(logits=logits) + # action = policy_dist.sample() + # action = math.int_to_one_hot(action, self.cfg.action_dim) + + # # Action probabilities for calculating the adapted soft-Q loss + # action_probs = policy_dist.probs + # log_prob = F.log_softmax(logits, dim=-1) + + # return action, action, log_prob, action_probs + + # Argmax policy + # enumerate all possible one-hot actions + # and return the one with the highest Q-value + # for the given state. + actions = torch.eye(self.cfg.action_dim, device=z.device).unsqueeze(0) + if z.dim() == 2: + # z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim) + z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1) + elif z.dim() == 3: + # z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim) + z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1) + actions = actions.unsqueeze(0).repeat(z.shape[0], z.shape[1], 1, 1) + Q = self.Q(z, actions, task, return_type='min') + action = Q.argmax(dim=-2) action = math.int_to_one_hot(action, self.cfg.action_dim) - # Action probabilities for calculating the adapted soft-Q loss - action_probs = policy_dist.probs - log_prob = F.log_softmax(logits, dim=-1) - - return action, action, log_prob, action_probs + return action, action, None, None def pi(self, z, task): diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 1deb61b..82f4313 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -107,6 +107,8 @@ class TDMPC2(torch.nn.Module): else: z = self.model.encode(obs, task) action = self.model.pi(z, task)[int(not eval_mode)][0] + if self.cfg.action == 'discrete': + action = action.squeeze(0) # TODO: this is a bit hacky return action.cpu() @torch.no_grad() @@ -234,6 +236,8 @@ class TDMPC2(torch.nn.Module): torch.Tensor: TD-target. """ pi = self.model.pi(next_z, task)[1] + if self.cfg.action == 'discrete': + pi = pi.squeeze(2) # TODO: this is a bit hacky discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) @@ -284,7 +288,10 @@ class TDMPC2(torch.nn.Module): self.optim.zero_grad(set_to_none=True) # Update policy - pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) + if self.cfg.action == 'continuous': + pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) + else: + pi_loss, pi_grad_norm = 0., 0. # Update target Q-functions self.model.soft_update_target_Q() diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 7c8f3c5..097cd61 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -98,6 +98,11 @@ class OnlineTrainer(Trainer): action = self.agent.act(obs, t0=len(self._tds)==1) else: action = self.env.rand_act() + if self.cfg.action == 'discrete': + # exploration schedule + # minimum 0.01, maximum 0.05, anneal over 20k steps + if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000): + action = self.env.rand_act() obs, reward, done, info = self.env.step(action) self._tds.append(self.to_td(obs, action, reward))