diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 12c2832..e4f56bf 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -83,7 +83,22 @@ class TDMPC2(torch.nn.Module): fp (str or dict): Filepath or state dict to load. """ state_dict = fp if isinstance(fp, dict) else torch.load(fp) - self.model.load_state_dict(state_dict["model"]) + state_dict = state_dict["model"] if "model" in state_dict else state_dict + try: # Checkpoints created AFTER Nov 10 update + self.model.load_state_dict(state_dict) + except: # Backwards compatibility + def _get_submodule(state_dict, key): + return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")} + for key in ["encoder", "dynamics", "reward", "pi"]: + submodule_state_dict = _get_submodule(state_dict, key) + getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict) + # Q-function requires special handling + Qs_state_dict = _get_submodule(state_dict, "Qs") + # TODO: figure out how to load Q-function state_dict from old checkpoints + raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \ + "please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \ + "until a fix is issued.") + self.model.load_state_dict(state_dict) @torch.no_grad() def act(self, obs, t0=False, eval_mode=False, task=None):