From df8a465c8e137c652a142f6ad6cdf540d3a6a39a Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 10 Dec 2024 16:30:05 -0800 Subject: [PATCH] update offline trainer to use new torch.load api --- tdmpc2/trainer/offline_trainer.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index cfe6d09..0d22fe8 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -38,13 +38,9 @@ class OfflineTrainer(Trainer): f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards), f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),}) return results - - def train(self): - """Train a TD-MPC2 agent.""" - assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \ - 'Offline training only supports multitask training with mt30 or mt80 task sets.' - - # Load data + + def _load_dataset(self): + """Load dataset for offline training.""" fp = Path(os.path.join(self.cfg.data_dir, '*.pt')) fps = sorted(glob(str(fp))) assert len(fps) > 0, f'No data found at {fp}' @@ -59,7 +55,7 @@ class OfflineTrainer(Trainer): _cfg.steps = _cfg.buffer_size self.buffer = Buffer(_cfg) for fp in tqdm(fps, desc='Loading data'): - td = torch.load(fp) + td = torch.load(fp, weights_only=False) assert td.shape[1] == _cfg.episode_length, \ f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \ f'please double-check your config.' @@ -68,6 +64,12 @@ class OfflineTrainer(Trainer): expected_episodes = _cfg.buffer_size // _cfg.episode_length assert self.buffer.num_eps == expected_episodes, \ f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.' + + def train(self): + """Train a TD-MPC2 agent.""" + assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \ + 'Offline training only supports multitask training with mt30 or mt80 task sets.' + self._load_dataset() print(f'Training agent for {self.cfg.steps} iterations...') metrics = {}