diff --git a/tdmpc2/common/layers.py b/tdmpc2/common/layers.py index dd34463..9c8ce57 100644 --- a/tdmpc2/common/layers.py +++ b/tdmpc2/common/layers.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from tensordict import from_modules from copy import deepcopy + class Ensemble(nn.Module): """ Vectorized ensemble of modules. @@ -161,3 +162,58 @@ def enc(cfg, out={}): else: raise NotImplementedError(f"Encoder for observation type {k} not implemented.") return nn.ModuleDict(out) + + +def api_model_conversion(target_state_dict, source_state_dict): + """ + Converts a checkpoint from our old API to the new torch.compile compatible API. + """ + # check whether checkpoint is already in the new format + if "_detach_Qs_params.0.weight" in source_state_dict: + return source_state_dict + + name_map = ['weight', 'bias', 'ln.weight', 'ln.bias'] + new_state_dict = dict() + + # rename keys + for key, val in list(source_state_dict.items()): + if key.startswith('_Qs.'): + num = key[len('_Qs.params.'):] + new_key = str(int(num) // 4) + "." + name_map[int(num) % 4] + new_total_key = "_Qs.params." + new_key + del source_state_dict[key] + new_state_dict[new_total_key] = val + new_total_key = "_detach_Qs_params." + new_key + new_state_dict[new_total_key] = val + elif key.startswith('_target_Qs.'): + num = key[len('_target_Qs.params.'):] + new_key = str(int(num) // 4) + "." + name_map[int(num) % 4] + new_total_key = "_target_Qs_params." + new_key + del source_state_dict[key] + new_state_dict[new_total_key] = val + + # add batch_size and device from target_state_dict to new_state_dict + for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_'): + for key in ('__batch_size', '__device'): + new_key = prefix + 'params.' + key + new_state_dict[new_key] = target_state_dict[new_key] + + # check that every key in new_state_dict is in target_state_dict + for key in new_state_dict.keys(): + assert key in target_state_dict, f"key {key} not in target_state_dict" + # check that all Qs keys in target_state_dict are in new_state_dict + for key in target_state_dict.keys(): + if 'Qs' in key: + assert key in new_state_dict, f"key {key} not in new_state_dict" + # check that source_state_dict contains no Qs keys + for key in source_state_dict.keys(): + assert 'Qs' not in key, f"key {key} contains 'Qs'" + + # copy log_std_min and log_std_max from target_state_dict to new_state_dict + new_state_dict['log_std_min'] = target_state_dict['log_std_min'] + new_state_dict['log_std_dif'] = target_state_dict['log_std_dif'] + + # copy new_state_dict to source_state_dict + source_state_dict.update(new_state_dict) + + return source_state_dict diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 8e39f08..a758fc7 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -4,6 +4,7 @@ import torch.nn.functional as F from common import math from common.scale import RunningScale from common.world_model import WorldModel +from common.layers import api_model_conversion from tensordict import TensorDict @@ -84,39 +85,7 @@ class TDMPC2(torch.nn.Module): """ state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device()) state_dict = state_dict["model"] if "model" in state_dict else state_dict - def load_sd_hook(model, local_state_dict, prefix, *args): - name_map = [ - "weight", "bias", "ln.weight", "ln.bias", - ] - # print("Listing state dict keys (from disk)") - # for k in list(local_state_dict.keys()): - # print("\t", k) - - sd = model.state_dict() - # print("Listing dest state dict keys") - # for k in list(sd.keys()): - # print("\t", k) - - # print("Maps:") - new_sd = dict(sd) - for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"): - for key, val in list(local_state_dict.items()): - if not key.startswith(cur_prefix[:-1]): - continue - num = key[len(cur_prefix + "params."):] - new_key = str(int(num) // 4) + "." + name_map[int(num) % 4] - new_total_key = cur_prefix + 'params.' + new_key - # print("\t", key, '-->', new_total_key) - del local_state_dict[key] - new_sd[new_total_key] = val - if not cur_prefix.startswith("_target"): - new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key - # print("\t", 'DETACH', key, '-->', new_total_key) - new_sd[new_total_key] = val - local_state_dict.update(new_sd) - return local_state_dict - load_sd_hook(self.model, state_dict, "_Qs.") - assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys())) + state_dict = api_model_conversion(self.model.state_dict(), state_dict) self.model.load_state_dict(state_dict) return