Merge branch 'speedups' of github.com:nicklashansen/tdmpc2 into speedups
This commit is contained in:
@@ -27,6 +27,7 @@ class OfflineTrainer(Trainer):
|
||||
for _ in range(self.cfg.eval_episodes):
|
||||
obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
|
||||
while not done:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
action = self.agent.act(obs, t0=t==0, eval_mode=True, task=task_idx)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
ep_reward += reward
|
||||
|
||||
@@ -31,6 +31,7 @@ class OnlineTrainer(Trainer):
|
||||
if self.cfg.save_video:
|
||||
self.logger.video.init(self.env, enabled=(i==0))
|
||||
while not done:
|
||||
torch.compiler.cudagraph_mark_step_begin()
|
||||
action = self.agent.act(obs, t0=t==0, eval_mode=True)
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
ep_reward += reward
|
||||
|
||||
Reference in New Issue
Block a user