partial fix for loading old checkpoints

This commit is contained in:
Nicklas Hansen
2024-12-10 16:04:27 -08:00
parent 6117bc427d
commit 2e27fbb6f4

View File

@@ -83,7 +83,22 @@ class TDMPC2(torch.nn.Module):
fp (str or dict): Filepath or state dict to load.
"""
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
self.model.load_state_dict(state_dict["model"])
state_dict = state_dict["model"] if "model" in state_dict else state_dict
try: # Checkpoints created AFTER Nov 10 update
self.model.load_state_dict(state_dict)
except: # Backwards compatibility
def _get_submodule(state_dict, key):
return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")}
for key in ["encoder", "dynamics", "reward", "pi"]:
submodule_state_dict = _get_submodule(state_dict, key)
getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict)
# Q-function requires special handling
Qs_state_dict = _get_submodule(state_dict, "Qs")
# TODO: figure out how to load Q-function state_dict from old checkpoints
raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \
"please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \
"until a fix is issued.")
self.model.load_state_dict(state_dict)
@torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None):