update compile print
This commit is contained in:
@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
print('compiling - update')
|
print('Compiling update function with torch.compile...')
|
||||||
self._update = torch.compile(self._update, mode="reduce-overhead")
|
self._update = torch.compile(self._update, mode="reduce-overhead")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user