fix merge error

This commit is contained in:
Nicklas Hansen
2025-05-20 14:09:13 -07:00
parent 6116eb3fa5
commit a586d8f393

View File

@@ -137,12 +137,12 @@ class TDMPC2(torch.nn.Module):
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg') return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def _plan(self, z, t0=False, eval_mode=False, task=None): def _plan(self, obs, t0=False, eval_mode=False, task=None):
""" """
Plan a sequence of actions using the learned world model. Plan a sequence of actions using the learned world model.
Args: Args:
z (torch.Tensor): Latent state from which to plan. obs (torch.Tensor): Observation from which to plan.
t0 (bool): Whether this is the first observation in the episode. t0 (bool): Whether this is the first observation in the episode.
eval_mode (bool): Whether to use the mean of the action distribution. eval_mode (bool): Whether to use the mean of the action distribution.
task (Torch.Tensor): Task index (only used for multi-task experiments). task (Torch.Tensor): Task index (only used for multi-task experiments).
@@ -150,6 +150,8 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
torch.Tensor: Action to take in the environment. torch.Tensor: Action to take in the environment.
""" """
z = self.model.encode(obs, task)
# Sample policy trajectories # Sample policy trajectories
if self.cfg.num_pi_trajs > 0: if self.cfg.num_pi_trajs > 0:
pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device) pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)