Files
tdmpc2/tdmpc2/common/logger.py
2024-02-02 15:56:54 -08:00

241 lines
7.0 KiB
Python
Executable File

import os
import datetime
import re
import numpy as np
import pandas as pd
from termcolor import colored
from omegaconf import OmegaConf
from common import TASK_SET
# Metrics that Logger._print shows on the console, in the order they are printed.
# Each entry is (key looked up in the logged dict, short label printed for it,
# format type handled by Logger._format: "int", "float" or "time").
# NOTE(review): "iteration" and "step" both use the label "I" -- confirm intended.
CONSOLE_FORMAT = [
    ("iteration", "I", "int"),
    ("episode", "E", "int"),
    ("step", "I", "int"),
    ("episode_reward", "R", "float"),
    ("episode_success", "S", "float"),
    ("total_time", "T", "time"),
]
# Console color used for each log category; Logger.log asserts that the
# category it receives is one of these keys.
CAT_TO_COLOR = {
    "pretrain": "yellow",
    "train": "blue",
    "eval": "green",
}
def make_dir(dir_path):
    """Create directory (and any missing parents) if it does not already exist.

    Args:
        dir_path: Path of the directory to create (str or path-like).

    Returns:
        ``dir_path`` unchanged, so the call can be used inline in an assignment.

    Raises:
        OSError: If the directory cannot be created for any reason other than
            already existing as a directory (e.g. permission denied, or the
            path exists but is a regular file).
    """
    # exist_ok=True tolerates only the "already a directory" case; genuine
    # failures surface here instead of later, when the directory is first used.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path
def print_run(cfg):
    """
    Pretty-printing of current run information.
    Logger calls this method at initialization.

    Args:
        cfg: Run configuration. Reads ``task_title``, ``steps``, ``obs_shape``
            (a mapping; its values are printed), ``action_dim`` and
            ``exp_name``.
    """
    prefix, color, attrs = " ", "green", ["bold"]

    def _limstr(s, maxlen=36):
        """Return ``str(s)``, cut to ``maxlen`` characters plus "..." if longer."""
        # Convert first, then slice: slicing the raw value fails for
        # non-sequence types (e.g. ints) and truncates lists by element.
        text = str(s)
        return text[:maxlen] + "..." if len(text) > maxlen else text

    def _pprint(k, v):
        """Print one "Key: value" row with a colored, left-aligned key."""
        print(
            prefix + colored(f'{k.capitalize()+":":<15}', color, attrs=attrs), _limstr(v)
        )

    observations = ", ".join([str(v) for v in cfg.obs_shape.values()])
    kvs = [
        ("task", cfg.task_title),
        ("steps", f"{int(cfg.steps):,}"),
        ("observations", observations),
        ("actions", cfg.action_dim),
        ("experiment", cfg.exp_name),
    ]
    # Divider width follows the longest (possibly truncated) value, plus a
    # fixed 25-character allowance added on top.
    w = max(len(_limstr(v)) for _, v in kvs) + 25
    div = "-" * w
    print(div)
    for k, v in kvs:
        _pprint(k, v)
    print(div)
def cfg_to_group(cfg, return_list=False):
    """
    Build a wandb-safe group name from the task and experiment name.

    Every run of non-alphanumeric characters in ``cfg.exp_name`` is collapsed
    into a single "-"; ``cfg.task`` is used as-is.

    Args:
        cfg: Configuration exposing ``task`` and ``exp_name``.
        return_list: When true, return ``[task, sanitized_exp_name]`` instead
            of the two parts joined with "-".
    """
    task = cfg.task
    safe_exp_name = re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)
    parts = [task, safe_exp_name]
    if return_list:
        return parts
    return "-".join(parts)
class VideoRecorder:
    """Buffers rendered evaluation frames and logs them to wandb as a video."""

    def __init__(self, cfg, wandb, fps=15):
        """Store settings and create the ``eval_video`` directory under ``cfg.work_dir``.

        Args:
            cfg: Run configuration; ``cfg.work_dir`` must support the ``/`` operator.
            wandb: Object used for logging (its ``log`` and ``Video`` are
                called in ``save``).
            fps: Frame rate passed to ``wandb.Video``.
        """
        self.cfg = cfg
        self._save_dir = make_dir(cfg.work_dir / 'eval_video')
        self._wandb = wandb
        self.fps = fps
        # Recording stays off until init() is called.
        self.frames = []
        self.enabled = False

    def init(self, env, enabled=True):
        """Start a new recording: drop old frames and capture the first one."""
        self.frames = []
        # Recording needs a save directory, a wandb handle and the caller's
        # consent; the stored value is whatever the `and` chain yields.
        ready = self._save_dir and self._wandb
        self.enabled = ready and enabled
        self.record(env)

    def record(self, env):
        """Append the frame returned by ``env.render()`` if recording is enabled."""
        if not self.enabled:
            return
        self.frames.append(env.render())

    def save(self, step, key='videos/eval_video'):
        """Log the buffered frames to wandb under ``key`` at ``step``.

        Returns the result of ``wandb.log``, or ``None`` when recording is
        disabled or no frames were captured.
        """
        if not self.enabled or len(self.frames) == 0:
            return None
        # Assumes each frame is (H, W, C), so the stack is (T, H, W, C) and
        # the transpose yields (T, C, H, W) -- TODO confirm against env.render().
        clip = np.stack(self.frames).transpose(0, 3, 1, 2)
        video = self._wandb.Video(clip, fps=self.fps, format='mp4')
        return self._wandb.log({key: video}, step=step)
class Logger:
    """Primary logging object. Logs either locally or using wandb.

    Every call to ``log`` prints a console line. When wandb is enabled the
    metrics are also sent there, and evaluation results can additionally be
    written to ``eval.csv`` in the log directory.
    """

    def __init__(self, cfg):
        """Create the log/model directories and, unless disabled, start a wandb run.

        Args:
            cfg: Run configuration. Read directly here: ``work_dir``,
                ``save_csv``, ``save_agent``, ``seed``, ``disable_wandb``,
                ``wandb_silent``, ``save_video`` and, via ``cfg.get``, the
                optional ``wandb_project`` / ``wandb_entity``. When wandb is
                disabled, ``cfg.save_agent`` and ``cfg.save_video`` are set
                to ``False`` on the passed object.
        """
        self._log_dir = make_dir(cfg.work_dir)
        self._model_dir = make_dir(self._log_dir / "models")
        self._save_csv = cfg.save_csv
        self._save_agent = cfg.save_agent
        self._group = cfg_to_group(cfg)
        self._seed = cfg.seed
        self._eval = []
        print_run(cfg)
        self.project = cfg.get("wandb_project", "none")
        self.entity = cfg.get("wandb_entity", "none")
        if cfg.disable_wandb or self.project == "none" or self.entity == "none":
            print(colored("Wandb disabled.", "blue", attrs=["bold"]))
            # NOTE(review): self._save_agent was captured above, so clearing
            # cfg.save_agent here does not stop save_agent() from writing the
            # local checkpoint -- confirm this is intended.
            cfg.save_agent = False
            cfg.save_video = False
            self._wandb = None
            self._video = None
            return
        os.environ["WANDB_SILENT"] = "true" if cfg.wandb_silent else "false"
        # Imported lazily so wandb is only required when it is actually used.
        import wandb

        wandb.init(
            project=self.project,
            entity=self.entity,
            name=str(cfg.seed),
            group=self._group,
            tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
            dir=self._log_dir,
            config=OmegaConf.to_container(cfg, resolve=True),
        )
        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
        self._wandb = wandb
        self._video = (
            VideoRecorder(cfg, self._wandb)
            if self._wandb and cfg.save_video
            else None
        )

    @property
    def video(self):
        """The ``VideoRecorder`` in use, or ``None`` when video logging is off."""
        return self._video

    @property
    def model_dir(self):
        """Directory in which ``save_agent`` writes checkpoints."""
        return self._model_dir

    def save_agent(self, agent=None, identifier='final'):
        """Save ``agent`` to ``<model_dir>/<identifier>.pt`` and upload it to wandb.

        Does nothing when agent saving was disabled in the config or no agent
        is given. The wandb artifact is only created when wandb is enabled.

        Args:
            agent: Object exposing ``save(path)``.
            identifier: Checkpoint name, used for the file name and as the
                suffix of the artifact name.
        """
        if self._save_agent and agent:
            fp = self._model_dir / f'{str(identifier)}.pt'
            agent.save(fp)
            if self._wandb:
                artifact = self._wandb.Artifact(
                    self._group + '-' + str(self._seed) + '-' + str(identifier),
                    type='model',
                )
                artifact.add_file(fp)
                self._wandb.log_artifact(artifact)

    def finish(self, agent=None):
        """Save the final agent (best effort) and close the wandb run."""
        # Deliberately broad: a failed checkpoint must not prevent the wandb
        # run from being finished below.
        try:
            self.save_agent(agent)
        except Exception as e:
            print(colored(f"Failed to save model: {e}", "red"))
        if self._wandb:
            self._wandb.finish()

    def _format(self, key, value, ty):
        """Render one ``key: value`` console field.

        Args:
            key: Label to print (colored blue).
            value: Metric value.
            ty: One of "int" (thousands separators), "float" (one decimal)
                or "time" (seconds rendered as H:MM:SS via ``timedelta``).

        Raises:
            ValueError: If ``ty`` is not one of the supported format types.
        """
        if ty == "int":
            return f'{colored(key+":", "blue")} {int(value):,}'
        elif ty == "float":
            return f'{colored(key+":", "blue")} {value:.01f}'
        elif ty == "time":
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{colored(key+":", "blue")} {value}'
        else:
            # Raising a bare string is itself a TypeError in Python 3, which
            # would hide this message; raise a real exception instead.
            raise ValueError(f"invalid log format type: {ty}")

    def _print(self, d, category):
        """Print one console line with the ``CONSOLE_FORMAT`` fields present in ``d``."""
        category = colored(category, CAT_TO_COLOR[category])
        pieces = [f" {category:<14}"]
        for k, disp_k, ty in CONSOLE_FORMAT:
            if k in d:
                pieces.append(f"{self._format(disp_k, d[k], ty):<22}")
        print(" ".join(pieces))

    def pprint_multitask(self, d, cfg):
        """Pretty-print evaluation metrics for multi-task training.

        Only keys containing "+" are considered; the text after the first "+"
        is taken as the task name. Averages are printed and also written back
        into ``d`` under ``episode_reward+avg_dmcontrol`` and, for the mt80
        task, ``episode_reward+avg_metaworld`` /
        ``episode_success+avg_metaworld``.
        """
        print(colored(f'Evaluated agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
        dmcontrol_reward = []
        metaworld_reward = []
        metaworld_success = []
        for k, v in d.items():
            if '+' not in k:
                continue
            task = k.split('+')[1]
            if task in TASK_SET['mt30'] and k.startswith('episode_reward'):  # DMControl
                dmcontrol_reward.append(v)
                print(colored(f' {task:<22}\tR: {v:.01f}', 'yellow'))
            elif task in TASK_SET['mt80'] and task not in TASK_SET['mt30']:  # Meta-World
                # Rewards are only accumulated; the per-task line shows success.
                if k.startswith('episode_reward'):
                    metaworld_reward.append(v)
                elif k.startswith('episode_success'):
                    metaworld_success.append(v)
                    print(colored(f' {task:<22}\tS: {v:.02f}', 'yellow'))
        # The averages are inserted only after the loop above has finished
        # iterating over d.
        dmcontrol_reward = np.nanmean(dmcontrol_reward)
        d['episode_reward+avg_dmcontrol'] = dmcontrol_reward
        print(colored(f' {"dmcontrol":<22}\tR: {dmcontrol_reward:.01f}', 'yellow', attrs=['bold']))
        if cfg.task == 'mt80':
            metaworld_reward = np.nanmean(metaworld_reward)
            metaworld_success = np.nanmean(metaworld_success)
            d['episode_reward+avg_metaworld'] = metaworld_reward
            d['episode_success+avg_metaworld'] = metaworld_success
            print(colored(f' {"metaworld":<22}\tR: {metaworld_reward:.01f}', 'yellow', attrs=['bold']))
            print(colored(f' {"metaworld":<22}\tS: {metaworld_success:.02f}', 'yellow', attrs=['bold']))

    def log(self, d, category="train"):
        """Log a dict of metrics to wandb (if enabled), the eval CSV and the console.

        Args:
            d: Metric name -> value. With wandb enabled it must contain
                "iteration" (category "pretrain") or "step" (otherwise); with
                CSV saving on, "eval" entries must contain "step" and
                "episode_reward".
            category: One of the keys of ``CAT_TO_COLOR``.
        """
        assert category in CAT_TO_COLOR, f"invalid category: {category}"
        if self._wandb:
            # Pretraining is indexed by iteration, train/eval by environment step.
            xkey = "iteration" if category == "pretrain" else "step"
            _d = {category + "/" + k: v for k, v in d.items()}
            self._wandb.log(_d, step=d[xkey])
        if category == "eval" and self._save_csv:
            keys = ["step", "episode_reward"]
            self._eval.append(np.array([d[keys[0]], d[keys[1]]]))
            # The whole file is rewritten on every call with all rows so far.
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.csv", header=keys, index=None
            )
        self._print(d, category)