105 lines
3.6 KiB
Python
Executable File
105 lines
3.6 KiB
Python
Executable File
import os
|
|
from copy import deepcopy
|
|
from time import time
|
|
from pathlib import Path
|
|
from glob import glob
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from common.buffer import Buffer
|
|
from trainer.base import Trainer
|
|
|
|
|
|
class OfflineTrainer(Trainer):
    """Trainer class for multi-task offline TD-MPC2 training."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base ``Trainer`` and start the clock."""
        super().__init__(*args, **kwargs)
        # Wall-clock reference for the 'total_time' metric reported in train().
        self._start_time = time()

    def eval(self):
        """Evaluate a TD-MPC2 agent.

        Runs ``cfg.eval_episodes`` episodes for every task in ``cfg.tasks``
        and returns a dict with two entries per task,
        ``'episode_reward+<task>'`` and ``'episode_success+<task>'``, each
        holding the NaN-ignoring mean over that task's episodes.
        """
        results = dict()
        for task_idx in tqdm(range(len(self.cfg.tasks)), desc='Evaluating'):
            task_name = self.cfg.tasks[task_idx]
            ep_rewards, ep_successes = [], []
            for _ in range(self.cfg.eval_episodes):
                obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
                while not done:
                    # t0 tells the agent this is the first step of the episode.
                    action = self.agent.act(obs, t0=t == 0, eval_mode=True, task=task_idx)
                    obs, reward, done, info = self.env.step(action)
                    ep_reward += reward
                    t += 1
                ep_rewards.append(ep_reward)
                # Success is read from the info dict of the final step only.
                ep_successes.append(info['success'])
            results[f'episode_reward+{task_name}'] = np.nanmean(ep_rewards)
            results[f'episode_success+{task_name}'] = np.nanmean(ep_successes)
        return results

    def train(self):
        """Train a TD-MPC2 agent.

        Steps:
          1. Check the config and collect the ``*.pt`` files in ``cfg.data_dir``.
          2. Give this process every ``cfg.world_size``-th file, starting at
             index ``cfg.rank``.
          3. Load the assigned files into a ``Buffer``.
          4. Call ``agent.update`` for ``cfg.steps`` iterations; the process
             with rank 0 periodically logs, evaluates and saves the agent.

        Raises:
            AssertionError: if the config is not an mt30/mt80 multitask
                config, the data directory name does not contain the task
                name, no data files are found or assigned, or a loaded chunk
                has an unexpected episode length.
        """
        assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
            'Offline training only supports multitask training with mt30 or mt80 task sets.'
        is_main = self.cfg.rank == 0

        # Load data
        assert self.cfg.task in self.cfg.data_dir, \
            f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
            f'please double-check your config.'
        pattern = Path(os.path.join(self.cfg.data_dir, '*.pt'))
        # Sorted so that every process sees the same file order before sharding.
        fps = sorted(glob(str(pattern)))
        assert len(fps) > 0, f'No data found at {pattern}'
        if is_main:
            print(f'Found {len(fps)} files in {pattern}')

        # Distribute data across processes
        assert len(fps) >= self.cfg.world_size, \
            f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(fps)}'
        fps = fps[self.cfg.rank::self.cfg.world_size]
        print(f'Process {self.cfg.rank} has {len(fps)} files')
        assert len(fps) > 0, f'No data assigned to process {self.cfg.rank}'

        # Create buffer for sampling (on a config copy so self.cfg is untouched)
        _cfg = deepcopy(self.cfg)
        # NOTE(review): the episode lengths and buffer sizes below are
        # hard-coded per task set; presumably they match the datasets this
        # trainer was written for -- confirm before using other data.
        _cfg.episode_length = 101 if self.cfg.task == 'mt80' else 501
        _cfg.buffer_size = 550_450_000 if self.cfg.task == 'mt80' else 345_690_000
        # Each process only holds its own share of the data.
        _cfg.buffer_size //= self.cfg.world_size
        _cfg.steps = _cfg.buffer_size
        self.buffer = Buffer(_cfg)
        for data_fp in tqdm(fps, desc='Loading data'):
            # NOTE(review): torch.load unpickles arbitrary objects; only point
            # cfg.data_dir at trusted files.
            td = torch.load(data_fp)
            assert td.shape[1] == _cfg.episode_length, \
                f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
                f'please double-check your config.'
            for ep_idx in range(len(td)):
                self.buffer.add(td[ep_idx])
        # Warn only (do not abort) when more data was added than the buffer
        # was sized for.
        if self.buffer.num_transitions > self.buffer.capacity:
            print(f'Buffer has {self.buffer.num_transitions} transitions, '
                  f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')

        if is_main:
            print(f'Training agent for {self.cfg.steps} iterations...')
        metrics = {}
        for i in range(self.cfg.steps):

            # Update agent
            train_metrics = self.agent.update(self.buffer)

            # Evaluate agent periodically (rank 0 only). Training metrics are
            # logged every 10k iterations; a full evaluation and checkpoint
            # happen every cfg.eval_freq iterations.
            if is_main and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
                metrics = {
                    'iteration': i,
                    'total_time': time() - self._start_time,
                }
                metrics.update(train_metrics)
                if i % self.cfg.eval_freq == 0:
                    metrics.update(self.eval())
                    self.logger.pprint_multitask(metrics, self.cfg)
                    # Skip the checkpoint at iteration 0.
                    if i > 0:
                        self.logger.save_agent(self.agent, identifier=f'{i}')
                self.logger.log(metrics, 'pretrain')

        if is_main:
            self.logger.finish(self.agent)
|