diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 8e6b435..e4d8ec2 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module): ) if self.cfg.multitask else self._get_discount(cfg.episode_length) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: - print('compiling - update') + print('Compiling update function with torch.compile...') self._update = torch.compile(self._update, mode="reduce-overhead") @property