easier customization of architecture:
all args can now be set freely when model_size is not specified
This commit is contained in:
@@ -40,6 +40,7 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins-1) # Bin size for discrete regression
|
cfg.bin_size = (cfg.vmax - cfg.vmin) / (cfg.num_bins-1) # Bin size for discrete regression
|
||||||
|
|
||||||
# Model size
|
# Model size
|
||||||
|
if cfg.get('model_size', None) is not None:
|
||||||
assert cfg.model_size in MODEL_SIZE.keys(), \
|
assert cfg.model_size in MODEL_SIZE.keys(), \
|
||||||
f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}'
|
f'Invalid model size {cfg.model_size}. Must be one of {list(MODEL_SIZE.keys())}'
|
||||||
for k, v in MODEL_SIZE[cfg.model_size].items():
|
for k, v in MODEL_SIZE[cfg.model_size].items():
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ vmin: -10
|
|||||||
vmax: +10
|
vmax: +10
|
||||||
|
|
||||||
# architecture
|
# architecture
|
||||||
model_size: 5
|
model_size: ???
|
||||||
num_enc_layers: 2
|
num_enc_layers: 2
|
||||||
enc_dim: 256
|
enc_dim: 256
|
||||||
mlp_dim: 512
|
mlp_dim: 512
|
||||||
|
|||||||
Reference in New Issue
Block a user