Files
tdmpc2/tdmpc2/trainer/offline_trainer.py
2024-01-07 11:52:53 -08:00

105 lines
3.6 KiB
Python
Executable File

import os
from copy import deepcopy
from time import time
from pathlib import Path
from glob import glob
import numpy as np
import torch
from tqdm import tqdm
from common.buffer import Buffer
from trainer.base import Trainer
class OfflineTrainer(Trainer):
    """Trainer class for multi-task offline TD-MPC2 training."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Trainer`` and record the start time."""
        super().__init__(*args, **kwargs)
        # Reference point for the 'total_time' metric reported by train().
        self._start_time = time()

    def eval(self):
        """Evaluate a TD-MPC2 agent.

        Runs ``cfg.eval_episodes`` episodes for every entry of ``cfg.tasks``.

        Returns:
            dict with two entries per task, keyed
            ``'episode_reward+<task>'`` and ``'episode_success+<task>'``,
            each holding the NaN-ignoring mean over that task's episodes.
        """
        results = dict()
        for task_idx in tqdm(range(len(self.cfg.tasks)), desc='Evaluating'):
            task_name = self.cfg.tasks[task_idx]
            ep_rewards, ep_successes = [], []
            for _ in range(self.cfg.eval_episodes):
                obs, done, ep_reward, t = self.env.reset(task_idx), False, 0, 0
                while not done:
                    # t0 flags the first step of the episode to the agent.
                    action = self.agent.act(
                        obs, t0=t == 0, eval_mode=True, task=task_idx)
                    obs, reward, done, info = self.env.step(action)
                    ep_reward += reward
                    t += 1
                ep_rewards.append(ep_reward)
                # Success is read from the info dict of the final step.
                ep_successes.append(info['success'])
            results[f'episode_reward+{task_name}'] = np.nanmean(ep_rewards)
            results[f'episode_success+{task_name}'] = np.nanmean(ep_successes)
        return results

    def train(self):
        """Train a TD-MPC2 agent.

        Steps:
            1. Check the config (multitask, task in {'mt30', 'mt80'}, task
               name contained in ``cfg.data_dir``).
            2. Glob ``*.pt`` files in ``cfg.data_dir`` and give this process
               every ``cfg.world_size``-th file starting at ``cfg.rank``.
            3. Build a ``Buffer`` from a copy of the config and add every
               episode of every assigned file to it.
            4. Call ``agent.update`` ``cfg.steps`` times; rank 0 logs,
               evaluates and saves the agent periodically.

        Raises:
            AssertionError: if the config is unsupported, no data files are
                found, there are fewer files than processes, or a loaded
                file's episode length does not match the expected one.
        """
        assert self.cfg.multitask and self.cfg.task in {'mt30', 'mt80'}, \
            'Offline training only supports multitask training with mt30 or mt80 task sets.'

        # Load data
        assert self.cfg.task in self.cfg.data_dir, \
            f'Expected data directory {self.cfg.data_dir} to contain {self.cfg.task}, ' \
            f'please double-check your config.'
        pattern = Path(os.path.join(self.cfg.data_dir, '*.pt'))
        all_files = sorted(glob(str(pattern)))
        assert len(all_files) > 0, f'No data found at {pattern}'
        if self.cfg.rank == 0:
            print(f'Found {len(all_files)} files in {pattern}')

        # Distribute data across processes
        assert len(all_files) >= self.cfg.world_size, \
            f'World size {self.cfg.world_size} cannot be greater than number of data chunks {len(all_files)}'
        # Strided slice: rank r takes files r, r + world_size, r + 2 * world_size, ...
        rank_files = all_files[self.cfg.rank::self.cfg.world_size]
        print(f'Process {self.cfg.rank} has {len(rank_files)} files')
        assert len(rank_files) > 0, f'No data assigned to process {self.cfg.rank}'

        # Create buffer for sampling
        # The buffer gets its own config copy so self.cfg stays untouched.
        _cfg = deepcopy(self.cfg)
        is_mt80 = self.cfg.task == 'mt80'
        _cfg.episode_length = 101 if is_mt80 else 501
        # NOTE(review): these sizes presumably match the full mt80 / mt30
        # datasets -- confirm against the dataset documentation.
        _cfg.buffer_size = 550_450_000 if is_mt80 else 345_690_000
        # Each process only holds its own share of the data.
        _cfg.buffer_size //= self.cfg.world_size
        _cfg.steps = _cfg.buffer_size
        self.buffer = Buffer(_cfg)
        for data_file in tqdm(rank_files, desc='Loading data'):
            # NOTE(security): torch.load unpickles the file; only point
            # cfg.data_dir at trusted data.
            td = torch.load(data_file)
            assert td.shape[1] == _cfg.episode_length, \
                f'Expected episode length {td.shape[1]} to match config episode length {_cfg.episode_length}, ' \
                f'please double-check your config.'
            for i in range(len(td)):
                self.buffer.add(td[i])
        # Warn only (no abort) when more data was added than the buffer's capacity.
        if self.buffer.num_transitions > self.buffer.capacity:
            print(f'Buffer has {self.buffer.num_transitions} transitions, '
                  f'expected maximum {self.buffer.capacity} transitions in process {self.cfg.rank}.')

        if self.cfg.rank == 0:
            print(f'Training agent for {self.cfg.steps} iterations...')
        for i in range(self.cfg.steps):

            # Update agent
            train_metrics = self.agent.update(self.buffer)

            # Evaluate agent periodically
            # Only rank 0 logs, evaluates and saves; other ranks just update.
            if self.cfg.rank == 0 and (i % self.cfg.eval_freq == 0 or i % 10_000 == 0):
                metrics = {
                    'iteration': i,
                    'total_time': time() - self._start_time,
                }
                metrics.update(train_metrics)
                if i % self.cfg.eval_freq == 0:
                    metrics.update(self.eval())
                    self.logger.pprint_multitask(metrics, self.cfg)
                    # No agent is saved at iteration 0.
                    if i > 0:
                        self.logger.save_agent(self.agent, identifier=f'{i}')
                self.logger.log(metrics, 'pretrain')

        if self.cfg.rank == 0:
            self.logger.finish(self.agent)