update offline trainer to use new torch.load api
This commit is contained in:
@@ -39,12 +39,8 @@ class OfflineTrainer(Trainer):
|
|||||||
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
|
f'episode_success+{self.cfg.tasks[task_idx]}': np.nanmean(ep_successes),})
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def train(self):
|
def _load_dataset(self):
|
||||||
"""Train a TD-MPC2 agent."""
|
"""Load dataset for offline training."""
|
||||||
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
|
||||||
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
|
||||||
|
|
||||||
# Load data
|
|
||||||
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
fp = Path(os.path.join(self.cfg.data_dir, '*.pt'))
|
||||||
fps = sorted(glob(str(fp)))
|
fps = sorted(glob(str(fp)))
|
||||||
assert len(fps) > 0, f'No data found at {fp}'
|
assert len(fps) > 0, f'No data found at {fp}'
|
||||||
@@ -59,7 +55,7 @@ class OfflineTrainer(Trainer):
|
|||||||
_cfg.steps = _cfg.buffer_size
|
_cfg.steps = _cfg.buffer_size
|
||||||
self.buffer = Buffer(_cfg)
|
self.buffer = Buffer(_cfg)
|
||||||
for fp in tqdm(fps, desc='Loading data'):
|
for fp in tqdm(fps, desc='Loading data'):
|
||||||
td = torch.load(fp)
|
td = torch.load(fp, weights_only=False)
|
||||||
assert td.shape[1] == _cfg.episode_length, \
|
assert td.shape[1] == _cfg.episode_length, \
|
||||||
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
|
||||||
f'please double-check your config.'
|
f'please double-check your config.'
|
||||||
@@ -69,6 +65,12 @@ class OfflineTrainer(Trainer):
|
|||||||
assert self.buffer.num_eps == expected_episodes, \
|
assert self.buffer.num_eps == expected_episodes, \
|
||||||
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
f'Buffer has {self.buffer.num_eps} episodes, expected {expected_episodes} episodes.'
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
"""Train a TD-MPC2 agent."""
|
||||||
|
assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
|
||||||
|
'Offline training only supports multitask training with mt30 or mt80 task sets.'
|
||||||
|
self._load_dataset()
|
||||||
|
|
||||||
print(f'Training agent for {self.cfg.steps} iterations...')
|
print(f'Training agent for {self.cfg.steps} iterations...')
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for i in range(self.cfg.steps):
|
for i in range(self.cfg.steps):
|
||||||
|
|||||||
Reference in New Issue
Block a user