From 0a914570dcc5383cad7b3e0f544dbf6fa4f56680 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Thu, 27 Feb 2025 16:25:21 -0800 Subject: [PATCH] fix multitask model api conversion --- tdmpc2/common/layers.py | 1 + tdmpc2/tdmpc2.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index 9c8ce57..951d67a 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -212,6 +212,7 @@ def api_model_conversion(target_state_dict, source_state_dict): # copy log_std_min and log_std_max from target_state_dict to new_state_dict new_state_dict['log_std_min'] = target_state_dict['log_std_min'] new_state_dict['log_std_dif'] = target_state_dict['log_std_dif'] + new_state_dict['_action_masks'] = target_state_dict['_action_masks'] # copy new_state_dict to source_state_dict source_state_dict.update(new_state_dict) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index a758fc7..a59216a 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -83,7 +83,10 @@ class TDMPC2(torch.nn.Module): Args: fp (str or dict): Filepath or state dict to load. """ - state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device()) + if isinstance(fp, dict): + state_dict = fp + else: + state_dict = torch.load(fp, map_location=torch.get_default_device(), weights_only=False) state_dict = state_dict["model"] if "model" in state_dict else state_dict state_dict = api_model_conversion(self.model.state_dict(), state_dict) self.model.load_state_dict(state_dict)