partial fix for loading old checkpoints
This commit is contained in:
@@ -83,7 +83,22 @@ class TDMPC2(torch.nn.Module):
|
||||
fp (str or dict): Filepath or state dict to load.
|
||||
"""
|
||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
|
||||
self.model.load_state_dict(state_dict["model"])
|
||||
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||
try: # Checkpoints created AFTER Nov 10 update
|
||||
self.model.load_state_dict(state_dict)
|
||||
except: # Backwards compatibility
|
||||
def _get_submodule(state_dict, key):
|
||||
return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")}
|
||||
for key in ["encoder", "dynamics", "reward", "pi"]:
|
||||
submodule_state_dict = _get_submodule(state_dict, key)
|
||||
getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict)
|
||||
# Q-function requires special handling
|
||||
Qs_state_dict = _get_submodule(state_dict, "Qs")
|
||||
# TODO: figure out how to load Q-function state_dict from old checkpoints
|
||||
raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \
|
||||
"please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \
|
||||
"until a fix is issued.")
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, eval_mode=False, task=None):
|
||||
|
||||
Reference in New Issue
Block a user