diff --git a/tdmpc2/evaluate.py b/tdmpc2/evaluate.py index f5b8628..a9f04ea 100755 --- a/tdmpc2/evaluate.py +++ b/tdmpc2/evaluate.py @@ -44,7 +44,7 @@ def evaluate(cfg: dict): cfg = parse_cfg(cfg) set_seed(cfg.seed) print(colored(f'Task: {cfg.task}', 'blue', attrs=['bold'])) - print(colored(f'Model size: {cfg.model_size}', 'blue', attrs=['bold'])) + print(colored(f'Model size: {cfg.get("model_size", "default")}', 'blue', attrs=['bold'])) print(colored(f'Checkpoint: {cfg.checkpoint}', 'blue', attrs=['bold'])) if not cfg.multitask and ('mt80' in cfg.checkpoint or 'mt30' in cfg.checkpoint): print(colored('Warning: single-task evaluation of multi-task models is not currently supported.', 'red', attrs=['bold']))