update offline trainer to use new torch.load api

This commit is contained in:
Nicklas Hansen
2024-12-10 16:30:05 -08:00
parent 2e27fbb6f4
commit df8a465c8e

View File

@@ -38,13 +38,9 @@ class OfflineTrainer(Trainer):
f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards),
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
return results
def train(self):
"""Train a TD-MPC2 agent."""
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
'Offline training only supports multitask training with mt30 or mt80 task sets.'
# Load data
def _load_dataset(self):
"""Load dataset for offline training."""
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
fps = sorted(glob(str(fp)))
assert len(fps) > 0, f'No data found at {fp}'
@@ -59,7 +55,7 @@ class OfflineTrainer(Trainer):
_cfg.steps = _cfg.buffer_size
self.buffer = Buffer(_cfg)
for fp in tqdm(fps, desc='Loading data'):
td = torch.load(fp)
td = torch.load(fp, weights_only=False)
assert td.shape[1] == _cfg.episode_length, \
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
f'please double-check your config.'
@@ -68,6 +64,12 @@ class OfflineTrainer(Trainer):
expected_episodes = _cfg.buffer_size // _cfg.episode_length
assert self.buffer.num_eps == expected_episodes, \
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
def train(self):
"""Train a TD-MPC2 agent."""
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
'Offline training only supports multitask training with mt30 or mt80 task sets.'
self._load_dataset()
print(f'Training agent for {self.cfg.steps} iterations...')
metrics = {}