Files
tdmpc2/tdmpc2/trainer/online_trainer.py
2024-10-21 14:49:21 -07:00

116 lines
3.2 KiB
Python
Executable File

from time import time
import numpy as np
import torch
from tensordict.tensordict import TensorDict
from trainer.base import Trainer
class OnlineTrainer(Trainer):
    """Single-task online TD-MPC2 trainer.

    Alternates between stepping the environment, updating the agent from
    the replay buffer, and periodically evaluating the agent.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Trainer`` and reset the run counters."""
        super().__init__(*args, **kwargs)
        self._ep_idx = 0  # value last returned by ``self.buffer.add``
        self._step = 0  # environment steps taken so far by ``train``
        self._start_time = time()  # reference point for ``total_time``

    def common_metrics(self):
        """Return the step count, episode index and elapsed wall-clock time."""
        elapsed = time() - self._start_time
        return {
            'step': self._step,
            'episode': self._ep_idx,
            'total_time': elapsed,
        }

    def eval(self):
        """Run ``cfg.eval_episodes`` evaluation episodes and average the results.

        Returns:
            A dict with ``episode_reward`` and ``episode_success``, each the
            NaN-ignoring mean over the evaluated episodes.
        """
        returns = []
        successes = []
        for episode in range(self.cfg.eval_episodes):
            obs = self.env.reset()
            done = False
            episode_return = 0
            num_steps = 0
            if self.cfg.save_video:
                # Only the first evaluation episode is enabled for video.
                self.logger.video.init(self.env, enabled=(episode == 0))
            while not done:
                # ``t0`` flags the first step of the episode to the agent.
                action = self.agent.act(obs, t0=(num_steps == 0), eval_mode=True)
                obs, reward, done, info = self.env.step(action)
                episode_return += reward
                num_steps += 1
                if self.cfg.save_video:
                    self.logger.video.record(self.env)
            returns.append(episode_return)
            # ``info`` is the one returned by the episode's final step.
            successes.append(info['success'])
            if self.cfg.save_video:
                self.logger.video.save(self._step)
        return {
            'episode_reward': np.nanmean(returns),
            'episode_success': np.nanmean(successes),
        }

    def to_td(self, obs, action=None, reward=None):
        """Pack one transition into a ``TensorDict`` with batch size ``(1,)``.

        ``action`` and ``reward`` default to NaN placeholders (the action
        placeholder is shaped like ``self.env.rand_act()``), which is how
        ``train`` builds the first entry of an episode from the reset
        observation alone.
        """
        if not isinstance(obs, dict):
            obs = obs.unsqueeze(0).cpu()
        else:
            obs = TensorDict(obs, batch_size=(), device='cpu')
        nan = float('nan')
        if action is None:
            action = torch.full_like(self.env.rand_act(), nan)
        if reward is None:
            reward = torch.tensor(nan)
        return TensorDict(
            obs=obs,
            action=action.unsqueeze(0),
            reward=reward.unsqueeze(0),
            batch_size=(1,),
        )

    def train(self):
        """Run the online training loop until ``self._step`` exceeds ``cfg.steps``."""
        metrics = {}
        done = True  # forces an environment reset on the first iteration
        eval_pending = False
        while self._step <= self.cfg.steps:
            # Flag an evaluation; it runs at the next episode boundary.
            if self._step % self.cfg.eval_freq == 0:
                eval_pending = True

            # Episode boundary: evaluate, log and store the finished episode,
            # then start a new one.
            if done:
                if eval_pending:
                    eval_metrics = self.eval()
                    eval_metrics.update(self.common_metrics())
                    self.logger.log(eval_metrics, 'eval')
                    eval_pending = False

                if self._step > 0:
                    # Skip the first entry: its reward is the NaN placeholder.
                    rewards = [td['reward'] for td in self._tds[1:]]
                    metrics.update(
                        episode_reward=torch.tensor(rewards).sum(),
                        episode_success=info['success'],
                    )
                    metrics.update(self.common_metrics())
                    self.logger.log(metrics, 'train')
                    self._ep_idx = self.buffer.add(torch.cat(self._tds))

                obs = self.env.reset()
                self._tds = [self.to_td(obs)]

            # Act randomly through the seed phase, then query the agent.
            if self._step <= self.cfg.seed_steps:
                action = self.env.rand_act()
            else:
                action = self.agent.act(obs, t0=(len(self._tds) == 1))
            obs, reward, done, info = self.env.step(action)
            self._tds.append(self.to_td(obs, action, reward))

            # Update the agent once the seed phase is over.
            if self._step >= self.cfg.seed_steps:
                if self._step == self.cfg.seed_steps:
                    print('Pretraining agent on seed data...')
                    num_updates = self.cfg.seed_steps
                else:
                    num_updates = 1
                for _ in range(num_updates):
                    update_metrics = self.agent.update(self.buffer)
                # Only the metrics of the last update are kept.
                metrics.update(update_metrics)

            self._step += 1

        self.logger.finish(self.agent)