diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index 3c4d0b1..88f0167 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -8,8 +8,8 @@ class Trainer: self.buffer = buffer self.logger = logger if cfg.rank == 0: - print("Learnable parameters: {:,}".format(self.agent.model.total_params)) print('Architecture:', self.agent.model) + print("Learnable parameters: {:,}".format(self.agent.model.total_params)) def eval(self): """Evaluate a TD-MPC2 agent."""