diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 0e1a7bd..ff41dbf 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -99,7 +99,7 @@ class TDMPC2(torch.nn.Module): Returns: torch.Tensor: Action to take in the environment. """ - obs = obs.to(self.device, non_blocking=True).unsqueeze(0) + obs = obs.to(self.device, non_blocking=True) if task is not None: task = torch.tensor([task], device=self.device) if self.cfg.mpc: @@ -154,7 +154,7 @@ class TDMPC2(torch.nn.Module): actions = torch.empty(self.cfg.num_envs, self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device) if self.cfg.num_pi_trajs > 0: actions[:, :, :self.cfg.num_pi_trajs] = pi_actions - + # Iterate MPPI for _ in range(self.cfg.iterations):