diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 6029b4f..3b44fed 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -137,12 +137,12 @@ class TDMPC2(torch.nn.Module): return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg') @torch.no_grad() - def _plan(self, z, t0=False, eval_mode=False, task=None): + def _plan(self, obs, t0=False, eval_mode=False, task=None): """ Plan a sequence of actions using the learned world model. Args: - z (torch.Tensor): Latent state from which to plan. + obs (torch.Tensor): Observation from which to plan. t0 (bool): Whether this is the first observation in the episode. eval_mode (bool): Whether to use the mean of the action distribution. task (Torch.Tensor): Task index (only used for multi-task experiments). @@ -150,6 +150,8 @@ class TDMPC2(torch.nn.Module): Returns: torch.Tensor: Action to take in the environment. """ + z = self.model.encode(obs, task) + # Sample policy trajectories if self.cfg.num_pi_trajs > 0: pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)