partial fix for loading old checkpoints
This commit is contained in:
@@ -83,7 +83,22 @@ class TDMPC2(torch.nn.Module):
|
|||||||
fp (str or dict): Filepath or state dict to load.
|
fp (str or dict): Filepath or state dict to load.
|
||||||
"""
|
"""
|
||||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
|
state_dict = fp if isinstance(fp, dict) else torch.load(fp)
|
||||||
self.model.load_state_dict(state_dict["model"])
|
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||||
|
try: # Checkpoints created AFTER Nov 10 update
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
except: # Backwards compatibility
|
||||||
|
def _get_submodule(state_dict, key):
|
||||||
|
return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")}
|
||||||
|
for key in ["encoder", "dynamics", "reward", "pi"]:
|
||||||
|
submodule_state_dict = _get_submodule(state_dict, key)
|
||||||
|
getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict)
|
||||||
|
# Q-function requires special handling
|
||||||
|
Qs_state_dict = _get_submodule(state_dict, "Qs")
|
||||||
|
# TODO: figure out how to load Q-function state_dict from old checkpoints
|
||||||
|
raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \
|
||||||
|
"please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \
|
||||||
|
"until a fix is issued.")
|
||||||
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def act(self, obs, t0=False, eval_mode=False, task=None):
|
def act(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user