Use torch.compiler.cudagraph_mark_step_begin() in eval

This commit is contained in:
Vincent Moens
2024-10-26 00:32:16 +01:00
committed by GitHub
parent 836547d76f
commit fad0d1be03

View File

@@ -31,6 +31,7 @@ class OnlineTrainer(Trainer):
if self.cfg.save_video:
self.logger.video.init(self.env, enabled=(i==0))
while not done:
torch.compiler.cudagraph_mark_step_begin()
action = self.agent.act(obs, t0=t==0, eval_mode=True)
obs, reward, done, info = self.env.step(action)
ep_reward += reward