update offline trainer to use new torch.load api
This commit is contained in:
@@ -38,13 +38,9 @@ class OfflineTrainer(Trainer):
|
||||
f'episode_reward+{self.cfg.tasks[task_idx]}': np.nanmean(ep_rewards),
|
||||
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
|
||||
return results
|
||||
|
||||
def train(self):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
||||
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
||||
|
||||
# Load data
|
||||
|
||||
def _load_dataset(self):
|
||||
"""Load dataset for offline training."""
|
||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||
fps = sorted(glob(str(fp)))
|
||||
assert len(fps) > 0, f'No data found at {fp}'
|
||||
@@ -59,7 +55,7 @@ class OfflineTrainer(Trainer):
|
||||
_cfg.steps = _cfg.buffer_size
|
||||
self.buffer = Buffer(_cfg)
|
||||
for fp in tqdm(fps, desc='Loading data'):
|
||||
td = torch.load(fp)
|
||||
td = torch.load(fp, weights_only=False)
|
||||
assert td.shape[1] == _cfg.episode_length, \
|
||||
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
||||
f'please double-check your config.'
|
||||
@@ -68,6 +64,12 @@ class OfflineTrainer(Trainer):
|
||||
expected_episodes = _cfg.buffer_size // _cfg.episode_length
|
||||
assert self.buffer.num_eps == expected_episodes, \
|
||||
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
||||
|
||||
def train(self):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
||||
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
||||
self._load_dataset()
|
||||
|
||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||
metrics = {}
|
||||
|
||||
Reference in New Issue
Block a user