move cfg conversion to parser.py

This commit is contained in:
Nicklas Hansen
2024-10-31 14:52:59 -07:00
parent d477619f8d
commit b7725e74a5
2 changed files with 20 additions and 37 deletions

View File

@@ -1,5 +1,7 @@
import dataclasses
import re
from pathlib import Path
from typing import Any
import hydra
from omegaconf import OmegaConf
@@ -7,6 +9,23 @@ from omegaconf import OmegaConf
from common import MODEL_SIZE, TASK_SET
def cfg_to_dataclass(cfg, frozen=False):
"""
Converts an OmegaConf config to a dataclass object.
This prevents graph breaks when used with torch.compile.
"""
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
"""
Parses a Hydra config. Mostly for convenience.
@@ -58,4 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
return cfg
return cfg_to_dataclass(cfg)

View File

@@ -18,42 +18,9 @@ from tdmpc2 import TDMPC2
from trainer.offline_trainer import OfflineTrainer
from trainer.online_trainer import OnlineTrainer
from common.logger import Logger
import dataclasses
from typing import Any
from omegaconf import OmegaConf
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('high')
def cfg_to_dataclass(cfg, frozen=False):
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
# Create the dataclass
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
def cfg_to_dataclass(cfg, frozen=False):
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
cfg_dict = OmegaConf.to_container(cfg)
fields = []
for key, value in cfg_dict.items():
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
# Create the dataclass
dataclass_name = "Config"
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
def get(self, val, default=None):
return getattr(self, val, default)
dataclass.get = get
return dataclass()
@hydra.main(config_name='config', config_path='.')
def train(cfg: dict):
@@ -82,9 +49,6 @@ def train(cfg: dict):
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
cfg = cfg_to_dataclass(cfg)
trainer = trainer_cls(
cfg=cfg,
env=make_env(cfg),