116 lines
3.2 KiB
Python
Executable File
116 lines
3.2 KiB
Python
Executable File
from time import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tensordict.tensordict import TensorDict
|
|
from trainer.base import Trainer
|
|
|
|
|
|
class OnlineTrainer(Trainer):
|
|
"""Trainer class for single-task online TD-MPC2 training."""
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__(*args, **kwargs)
|
|
self._step = 0
|
|
self._ep_idx = 0
|
|
self._start_time = time()
|
|
|
|
def common_metrics(self):
|
|
"""Return a dictionary of current metrics."""
|
|
return dict(
|
|
step=self._step,
|
|
episode=self._ep_idx,
|
|
total_time=time() - self._start_time,
|
|
)
|
|
|
|
def eval(self):
|
|
"""Evaluate a TD-MPC2 agent."""
|
|
ep_rewards, ep_successes = [], []
|
|
for i in range(self.cfg.eval_episodes):
|
|
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
|
if self.cfg.save_video:
|
|
self.logger.video.init(self.env, enabled=(i==0))
|
|
while not done:
|
|
action = self.agent.act(obs, t0=t==0, eval_mode=True)
|
|
obs, reward, done, info = self.env.step(action)
|
|
ep_reward += reward
|
|
t += 1
|
|
if self.cfg.save_video:
|
|
self.logger.video.record(self.env)
|
|
ep_rewards.append(ep_reward)
|
|
ep_successes.append(info['success'])
|
|
if self.cfg.save_video:
|
|
self.logger.video.save(self._step)
|
|
return dict(
|
|
episode_reward=np.nanmean(ep_rewards),
|
|
episode_success=np.nanmean(ep_successes),
|
|
)
|
|
|
|
def to_td(self, obs, action=None, reward=None):
|
|
"""Creates a TensorDict for a new episode."""
|
|
if isinstance(obs, dict):
|
|
obs = TensorDict(obs, batch_size=(), device='cpu')
|
|
else:
|
|
obs = obs.unsqueeze(0).cpu()
|
|
if action is None:
|
|
action = torch.full_like(self.env.rand_act(), float('nan'))
|
|
if reward is None:
|
|
reward = torch.tensor(float('nan'))
|
|
td = TensorDict(
|
|
obs=obs,
|
|
action=action.unsqueeze(0),
|
|
reward=reward.unsqueeze(0),
|
|
batch_size=(1,))
|
|
return td
|
|
|
|
def train(self):
|
|
"""Train a TD-MPC2 agent."""
|
|
train_metrics, done, eval_next = {}, True, False
|
|
while self._step <= self.cfg.steps:
|
|
# Evaluate agent periodically
|
|
if self._step % self.cfg.eval_freq == 0:
|
|
eval_next = True
|
|
|
|
# Reset environment
|
|
if done:
|
|
if eval_next:
|
|
eval_metrics = self.eval()
|
|
eval_metrics.update(self.common_metrics())
|
|
self.logger.log(eval_metrics, 'eval')
|
|
eval_next = False
|
|
|
|
if self._step > 0:
|
|
train_metrics.update(
|
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
|
episode_success=info['success'],
|
|
)
|
|
train_metrics.update(self.common_metrics())
|
|
self.logger.log(train_metrics, 'train')
|
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
|
|
|
obs = self.env.reset()
|
|
self._tds = [self.to_td(obs)]
|
|
|
|
# Collect experience
|
|
if self._step > self.cfg.seed_steps:
|
|
action = self.agent.act(obs, t0=len(self._tds)==1)
|
|
else:
|
|
action = self.env.rand_act()
|
|
obs, reward, done, info = self.env.step(action)
|
|
self._tds.append(self.to_td(obs, action, reward))
|
|
|
|
# Update agent
|
|
if self._step >= self.cfg.seed_steps:
|
|
if self._step == self.cfg.seed_steps:
|
|
num_updates = self.cfg.seed_steps
|
|
print('Pretraining agent on seed data...')
|
|
else:
|
|
num_updates = 1
|
|
for _ in range(num_updates):
|
|
_train_metrics = self.agent.update(self.buffer)
|
|
train_metrics.update(_train_metrics)
|
|
|
|
self._step += 1
|
|
|
|
self.logger.finish(self.agent)
|