From 445af9d81d9f459ebeec4f43995ede2ee573e1fd Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 6 Dec 2023 08:15:54 -0800 Subject: [PATCH] easier customization of architecture: all args can now be set freely when model_size is not specified --- tdmpc2/common/parser.py | 13 +++++++------ tdmpc2/config.yaml | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index f36731e..ddce2b4 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -40,12 +40,13 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins-1) # Bin size for discrete regression # Model size - assert cfg.model_size in MODEL_SIZE.keys(), \ - f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}' - for k, v in MODEL_SIZE[cfg.model_size].items(): - cfg[k] = v - if cfg.task == 'mt30' and cfg.model_size == 19: - cfg.latent_dim = 512 # This checkpoint is slightly smaller + if cfg.get('model_size', None) is not None: + assert cfg.model_size in MODEL_SIZE.keys(), \ + f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}' + for k, v in MODEL_SIZE[cfg.model_size].items(): + cfg[k] = v + if cfg.task == 'mt30' and cfg.model_size == 19: + cfg.latent_dim = 512 # This checkpoint is slightly smaller # Multi-task cfg.multitask = cfg.task in TASK_SET.keys() diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 3b945ee..b625bf5 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -49,7 +49,7 @@ vmin: -10 vmax: +10 # architecture -model_size: 5 +model_size: ??? num_enc_layers: 2 enc_dim: 256 mlp_dim: 512