fix multitask model api conversion
This commit is contained in:
@@ -212,6 +212,7 @@ def api_model_conversion(target_state_dict, source_state_dict):
|
|||||||
# copy log_std_min and log_std_max from target_state_dict to new_state_dict
|
# copy log_std_min and log_std_max from target_state_dict to new_state_dict
|
||||||
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
|
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
|
||||||
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
|
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
|
||||||
|
new_state_dict['_action_masks'] = target_state_dict['_action_masks']
|
||||||
|
|
||||||
# copy new_state_dict to source_state_dict
|
# copy new_state_dict to source_state_dict
|
||||||
source_state_dict.update(new_state_dict)
|
source_state_dict.update(new_state_dict)
|
||||||
|
|||||||
@@ -83,7 +83,10 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
fp (str or dict): Filepath or state dict to load.
|
fp (str or dict): Filepath or state dict to load.
|
||||||
"""
|
"""
|
||||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
|
if isinstance(fp, dict):
|
||||||
|
state_dict = fp
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(fp, map_location=torch.get_default_device(), weights_only=False)
|
||||||
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||||
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
|
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user