Files
tdmpc2/tdmpc2/evaluate.py
2025-01-21 00:10:53 -08:00

104 lines
3.5 KiB
Python
Executable File

import os
os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
import warnings
warnings.filterwarnings('ignore')
import hydra
import imageio
import numpy as np
import torch
from termcolor import colored
from common.parser import parse_cfg
from common.seed import set_seed
from envs import make_env
from tdmpc2 import TDMPC2
torch.backends.cudnn.benchmark = True
@hydra.main(config_name='config', config_path='.')
def evaluate(cfg: dict):
"""
Script for evaluating a single-task / multi-task TD-MPC2 checkpoint.
Most relevant args:
`task`: task name (or mt30/mt80 for multi-task evaluation)
`model_size`: model size, must be one of `[1, 5, 19, 48, 317]` (default: 5)
`checkpoint`: path to model checkpoint to load
`eval_episodes`: number of episodes to evaluate on per task (default: 10)
`save_video`: whether to save a video of the evaluation (default: True)
`seed`: random seed (default: 1)
See config.yaml for a full list of args.
Example usage:
````
$ python evaluate.py task=mt80 model_size=48 checkpoint=/path/to/mt80-48M.pt
$ python evaluate.py task=mt30 model_size=317 checkpoint=/path/to/mt30-317M.pt
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
```
"""
assert torch.cuda.is_available()
assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold']))
print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold']))
print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold']))
if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint):
print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))
print(colored('To evaluate a multi-task model, use task=mt80 or task=mt30.', 'red', attrs=['bold']))
# Make environment
env = make_env(cfg)
# Load agent
agent = TDMPC2(cfg)
assert os.path.exists(cfg.checkpoint), f'Checkpoint {cfg.checkpoint} not found! Must be a valid filepath.'
agent.load(cfg.checkpoint)
# Evaluate
if cfg.multitask:
print(colored(f'Evaluating agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
else:
print(colored(f'Evaluating agent on {cfg.task}:', 'yellow', attrs=['bold']))
if cfg.save_video:
video_dir = os.path.join(cfg.work_dir, 'videos')
os.makedirs(video_dir, exist_ok=True)
scores = []
tasks = cfg.tasks if cfg.multitask else [cfg.task]
for task_idx, task in enumerate(tasks):
if not cfg.multitask:
task_idx = None
ep_rewards, ep_successes = [], []
for i in range(cfg.eval_episodes):
obs, done, ep_reward, t = env.reset(task_idx=task_idx), False, 0, 0
if cfg.save_video:
frames = [env.render()]
while not done:
action = agent.act(obs, t0=t==0, task=task_idx)
obs, reward, done, info = env.step(action)
ep_reward += reward
t += 1
if cfg.save_video:
frames.append(env.render())
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
if cfg.save_video:
imageio.mimsave(
os.path.join(video_dir, f'{task}-{i}.mp4'), frames, fps=15)
ep_rewards = np.mean(ep_rewards)
ep_successes = np.mean(ep_successes)
if cfg.multitask:
scores.append(ep_successes*100 if task.startswith('mw-') else ep_rewards/10)
print(colored(f' {task:<22}' \
f'\tR: {ep_rewards:.01f} ' \
f'\tS: {ep_successes:.02f}', 'yellow'))
if cfg.multitask:
print(colored(f'Normalized score: {np.mean(scores):.02f}', 'yellow', attrs=['bold']))
if __name__ == '__main__':
evaluate()