diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 3925359..9d49cf8 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -85,7 +85,10 @@ class TDMPC2: if task is not None: task = torch.tensor([task], device=self.device) z = self.model.encode(obs, task) - a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) + if self.cfg.mpc: + a = self.plan(z, t0=t0, eval_mode=eval_mode, task=task) + else: + a = self.model.pi(z, task)[int(not eval_mode)][0] return a.cpu() @torch.no_grad()