move cfg conversion to parser.py
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import dataclasses
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import hydra
|
||||
from omegaconf import OmegaConf
|
||||
@@ -7,6 +9,23 @@ from omegaconf import OmegaConf
|
||||
from common import MODEL_SIZE, TASK_SET
|
||||
|
||||
|
||||
def cfg_to_dataclass(cfg, frozen=False):
|
||||
"""
|
||||
Converts an OmegaConf config to a dataclass object.
|
||||
This prevents graph breaks when used with torch.compile.
|
||||
"""
|
||||
cfg_dict = OmegaConf.to_container(cfg)
|
||||
fields = []
|
||||
for key, value in cfg_dict.items():
|
||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||
dataclass_name = "Config"
|
||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||
def get(self, val, default=None):
|
||||
return getattr(self, val, default)
|
||||
dataclass.get = get
|
||||
return dataclass()
|
||||
|
||||
|
||||
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||
"""
|
||||
Parses a Hydra config. Mostly for convenience.
|
||||
@@ -58,4 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||
cfg.task_dim = 0
|
||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||
|
||||
return cfg
|
||||
return cfg_to_dataclass(cfg)
|
||||
|
||||
@@ -18,42 +18,9 @@ from tdmpc2 import TDMPC2
|
||||
from trainer.offline_trainer import OfflineTrainer
|
||||
from trainer.online_trainer import OnlineTrainer
|
||||
from common.logger import Logger
|
||||
import dataclasses
|
||||
from typing import Any
|
||||
from omegaconf import OmegaConf
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
torch.set_float32_matmul_precision('high')
|
||||
|
||||
def cfg_to_dataclass(cfg, frozen=False):
|
||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||
cfg_dict = OmegaConf.to_container(cfg)
|
||||
fields = []
|
||||
for key, value in cfg_dict.items():
|
||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||
|
||||
# Create the dataclass
|
||||
dataclass_name = "Config"
|
||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||
def get(self, val, default=None):
|
||||
return getattr(self, val, default)
|
||||
dataclass.get = get
|
||||
return dataclass()
|
||||
|
||||
def cfg_to_dataclass(cfg, frozen=False):
|
||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
||||
cfg_dict = OmegaConf.to_container(cfg)
|
||||
fields = []
|
||||
for key, value in cfg_dict.items():
|
||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||
|
||||
# Create the dataclass
|
||||
dataclass_name = "Config"
|
||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||
def get(self, val, default=None):
|
||||
return getattr(self, val, default)
|
||||
dataclass.get = get
|
||||
return dataclass()
|
||||
|
||||
@hydra.main(config_name='config', config_path='.')
|
||||
def train(cfg: dict):
|
||||
@@ -82,9 +49,6 @@ def train(cfg: dict):
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
|
||||
cfg = cfg_to_dataclass(cfg)
|
||||
|
||||
trainer = trainer_cls(
|
||||
cfg=cfg,
|
||||
env=make_env(cfg),
|
||||
|
||||
Reference in New Issue
Block a user