From c16f2557bb6bb7f3e33b761e226c178536b4f65e Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 11:52:53 -0800 Subject: [PATCH] support distributed training --- tdmpc2/trainer/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tdmpc2/trainer/base.py b/tdmpc2/trainer/base.py index 88f0167..3c4d0b1 100755 --- a/tdmpc2/trainer/base.py +++ b/tdmpc2/trainer/base.py @@ -8,8 +8,8 @@ class Trainer: self.buffer = buffer self.logger = logger if cfg.rank == 0: - print('Architecture:', self.agent.model) print("Learnable parameters: {:,}".format(self.agent.model.total_params)) + print('Architecture:', self.agent.model) def eval(self): """Evaluate a TD-MPC2 agent."""