diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index 951d67a..97e977c 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -212,7 +212,8 @@ def api_model_conversion(target_state_dict, source_state_dict): # copy log_std_min and log_std_max from target_state_dict to new_state_dict new_state_dict['log_std_min'] = target_state_dict['log_std_min'] new_state_dict['log_std_dif'] = target_state_dict['log_std_dif'] - new_state_dict['_action_masks'] = target_state_dict['_action_masks'] + if '_action_masks' in target_state_dict: + new_state_dict['_action_masks'] = target_state_dict['_action_masks'] # copy new_state_dict to source_state_dict source_state_dict.update(new_state_dict)