diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index 88f0167..3c4d0b1 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -8,8 +8,8 @@ class Trainer: self.buffer = buffer self.logger = logger if cfg.rank == 0: - print('Architecture:', self.agent.model) print("Learnable parameters: {:,}".format(self.agent.model.total_params)) + print('Architecture:', self.agent.model) def eval(self): """Evaluate a TD-MPC2 agent."""