move cfg conversion to parser.py
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
|
import dataclasses
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -7,6 +9,23 @@ from omegaconf import OmegaConf
|
|||||||
from common import MODEL_SIZE, TASK_SET
|
from common import MODEL_SIZE, TASK_SET
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_to_dataclass(cfg, frozen=False):
|
||||||
|
"""
|
||||||
|
Converts an OmegaConf config to a dataclass object.
|
||||||
|
This prevents graph breaks when used with torch.compile.
|
||||||
|
"""
|
||||||
|
cfg_dict = OmegaConf.to_container(cfg)
|
||||||
|
fields = []
|
||||||
|
for key, value in cfg_dict.items():
|
||||||
|
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
||||||
|
dataclass_name = "Config"
|
||||||
|
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
||||||
|
def get(self, val, default=None):
|
||||||
|
return getattr(self, val, default)
|
||||||
|
dataclass.get = get
|
||||||
|
return dataclass()
|
||||||
|
|
||||||
|
|
||||||
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
||||||
"""
|
"""
|
||||||
Parses a Hydra config. Mostly for convenience.
|
Parses a Hydra config. Mostly for convenience.
|
||||||
@@ -58,4 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
cfg.task_dim = 0
|
cfg.task_dim = 0
|
||||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||||
|
|
||||||
return cfg
|
return cfg_to_dataclass(cfg)
|
||||||
|
|||||||
@@ -18,42 +18,9 @@ from tdmpc2 import TDMPC2
|
|||||||
from trainer.offline_trainer import OfflineTrainer
|
from trainer.offline_trainer import OfflineTrainer
|
||||||
from trainer.online_trainer import OnlineTrainer
|
from trainer.online_trainer import OnlineTrainer
|
||||||
from common.logger import Logger
|
from common.logger import Logger
|
||||||
import dataclasses
|
|
||||||
from typing import Any
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.set_float32_matmul_precision('high')
|
torch.set_float32_matmul_precision('high')
|
||||||
|
|
||||||
def cfg_to_dataclass(cfg, frozen=False):
|
|
||||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
|
||||||
cfg_dict = OmegaConf.to_container(cfg)
|
|
||||||
fields = []
|
|
||||||
for key, value in cfg_dict.items():
|
|
||||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
|
||||||
|
|
||||||
# Create the dataclass
|
|
||||||
dataclass_name = "Config"
|
|
||||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
|
||||||
def get(self, val, default=None):
|
|
||||||
return getattr(self, val, default)
|
|
||||||
dataclass.get = get
|
|
||||||
return dataclass()
|
|
||||||
|
|
||||||
def cfg_to_dataclass(cfg, frozen=False):
|
|
||||||
# Converts an OmegaConf config to a dataclass, which will not cause graph breaks
|
|
||||||
cfg_dict = OmegaConf.to_container(cfg)
|
|
||||||
fields = []
|
|
||||||
for key, value in cfg_dict.items():
|
|
||||||
fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_)))
|
|
||||||
|
|
||||||
# Create the dataclass
|
|
||||||
dataclass_name = "Config"
|
|
||||||
dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen)
|
|
||||||
def get(self, val, default=None):
|
|
||||||
return getattr(self, val, default)
|
|
||||||
dataclass.get = get
|
|
||||||
return dataclass()
|
|
||||||
|
|
||||||
@hydra.main(config_name='config', config_path='.')
|
@hydra.main(config_name='config', config_path='.')
|
||||||
def train(cfg: dict):
|
def train(cfg: dict):
|
||||||
@@ -82,9 +49,6 @@ def train(cfg: dict):
|
|||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||||
|
|
||||||
cfg = cfg_to_dataclass(cfg)
|
|
||||||
|
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
cfg=cfg,
|
cfg=cfg,
|
||||||
env=make_env(cfg),
|
env=make_env(cfg),
|
||||||
|
|||||||
Reference in New Issue
Block a user