From b7725e74a5320b3046bf604857d983737ed49393 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Thu, 31 Oct 2024 14:52:59 -0700 Subject: [PATCH] move cfg conversion to parser.py --- tdmpc2/common/parser.py | 21 ++++++++++++++++++++- tdmpc2/train.py | 36 ------------------------------------ 2 files changed, 20 insertions(+), 37 deletions(-) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index 378ba4a..a8d9f25 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -1,5 +1,7 @@ +import dataclasses import re from pathlib import Path +from typing import Any import hydra from omegaconf import OmegaConf @@ -7,6 +9,23 @@ from omegaconf import OmegaConf from common import MODEL_SIZE, TASK_SET +def cfg_to_dataclass(cfg, frozen=False): + """ + Converts an OmegaConf config to a dataclass object. + This prevents graph breaks when used with torch.compile. + """ + cfg_dict = OmegaConf.to_container(cfg) + fields = [] + for key, value in cfg_dict.items(): + fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) + dataclass_name = "Config" + dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) + def get(self, val, default=None): + return getattr(self, val, default) + dataclass.get = get + return dataclass() + + def parse_cfg(cfg: OmegaConf) -> OmegaConf: """ Parses a Hydra config. Mostly for convenience. @@ -58,4 +77,4 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) - return cfg + return cfg_to_dataclass(cfg) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 48206ec..3dc37a6 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -18,42 +18,9 @@ from tdmpc2 import TDMPC2 from trainer.offline_trainer import OfflineTrainer from trainer.online_trainer import OnlineTrainer from common.logger import Logger -import dataclasses -from typing import Any -from omegaconf import OmegaConf torch.backends.cudnn.benchmark = True - torch.set_float32_matmul_precision('high') -def cfg_to_dataclass(cfg, frozen=False): - # Converts an OmegaConf config to a dataclass, which will not cause graph breaks - cfg_dict = OmegaConf.to_container(cfg) - fields = [] - for key, value in cfg_dict.items(): - fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) - - # Create the dataclass - dataclass_name = "Config" - dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) - def get(self, val, default=None): - return getattr(self, val, default) - dataclass.get = get - return dataclass() - -def cfg_to_dataclass(cfg, frozen=False): - # Converts an OmegaConf config to a dataclass, which will not cause graph breaks - cfg_dict = OmegaConf.to_container(cfg) - fields = [] - for key, value in cfg_dict.items(): - fields.append((key, Any, dataclasses.field(default_factory=lambda value_=value: value_))) - - # Create the dataclass - dataclass_name = "Config" - dataclass = dataclasses.make_dataclass(dataclass_name, fields, frozen=frozen) - def get(self, val, default=None): - return getattr(self, val, default) - dataclass.get = get - return dataclass() @hydra.main(config_name='config', config_path='.') def train(cfg: dict): @@ -82,9 +49,6 @@ def train(cfg: dict): print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer - - cfg = cfg_to_dataclass(cfg) - trainer = trainer_cls( cfg=cfg, env=make_env(cfg),