From fad0d1be0361d15997c9cf8594f1f47dad226ca4 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 26 Oct 2024 00:32:16 +0100
Subject: [PATCH 1/2] Use torch.compiler.cudagraph_mark_step_begin() in eval

---
NOTE(review): line breaks and hunk indentation (tabs assumed) were
reconstructed from a whitespace-collapsed copy of this patch; confirm the
indentation against the target file before applying.

 tdmpc2/trainer/online_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py
index 3a47542..0d2f062 100755
--- a/tdmpc2/trainer/online_trainer.py
+++ b/tdmpc2/trainer/online_trainer.py
@@ -31,6 +31,7 @@ class OnlineTrainer(Trainer):
 			if self.cfg.save_video:
 				self.logger.video.init(self.env, enabled=(i==0))
 			while not done:
+				torch.compiler.cudagraph_mark_step_begin()
 				action = self.agent.act(obs, t0=t==0, eval_mode=True)
 				obs, reward, done, info = self.env.step(action)
 				ep_reward += reward

From 3b5f67592ccd8d4c511bea490e97b33cd67a7a7b Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Sat, 26 Oct 2024 00:33:27 +0100
Subject: [PATCH 2/2] Update offline_trainer.py

---
NOTE(review): line breaks and hunk indentation (tabs assumed) were
reconstructed from a whitespace-collapsed copy of this patch; confirm the
indentation against the target file before applying.

 tdmpc2/trainer/offline_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py
index 1bace8e..89f1c20 100755
--- a/tdmpc2/trainer/offline_trainer.py
+++ b/tdmpc2/trainer/offline_trainer.py
@@ -27,6 +27,7 @@ class OfflineTrainer(Trainer):
 				for _ in range(self.cfg.eval_episodes):
 					obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
 					while not done:
+						torch.compiler.cudagraph_mark_step_begin()
 						action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx)
 						obs, reward, done, info = self.env.step(action)
 						ep_reward += reward