fix merge error
This commit is contained in:
@@ -137,12 +137,12 @@ class TDMPC2(torch.nn.Module):
|
||||
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
|
||||
def _plan(self, z, t0=False, eval_mode=False, task=None):
|
||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||
"""
|
||||
Plan a sequence of actions using the learned world model.
|
||||
|
||||
Args:
|
||||
z (torch.Tensor): Latent state from which to plan.
|
||||
obs (torch.Tensor): Observation from which to plan.
|
||||
t0 (bool): Whether this is the first observation in the episode.
|
||||
eval_mode (bool): Whether to use the mean of the action distribution.
|
||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||
@@ -150,6 +150,8 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
z = self.model.encode(obs, task)
|
||||
|
||||
# Sample policy trajectories
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
pi_actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user