auto-convert old checkpoints to new format

This commit is contained in:
Nicklas Hansen
2025-02-05 16:26:19 -08:00
parent dddc226d25
commit 5ced6dfeb4
2 changed files with 58 additions and 33 deletions

View File

@@ -4,6 +4,7 @@ import torch.nn.functional as F
from tensordict import from_modules
from copy import deepcopy
class Ensemble(nn.Module):
"""
Vectorized ensemble of modules.
@@ -161,3 +162,58 @@ def enc(cfg, out={}):
else:
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
return nn.ModuleDict(out)
def api_model_conversion(target_state_dict, source_state_dict):
"""
Converts a checkpoint from our old API to the new torch.compile compatible API.
"""
# check whether checkpoint is already in the new format
if "_detach_Qs_params.0.weight" in source_state_dict:
return source_state_dict
name_map = ['weight', 'bias', 'ln.weight', 'ln.bias']
new_state_dict = dict()
# rename keys
for key, val in list(source_state_dict.items()):
if key.startswith('_Qs.'):
num = key[len('_Qs.params.'):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = "_Qs.params." + new_key
del source_state_dict[key]
new_state_dict[new_total_key] = val
new_total_key = "_detach_Qs_params." + new_key
new_state_dict[new_total_key] = val
elif key.startswith('_target_Qs.'):
num = key[len('_target_Qs.params.'):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = "_target_Qs_params." + new_key
del source_state_dict[key]
new_state_dict[new_total_key] = val
# add batch_size and device from target_state_dict to new_state_dict
for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_'):
for key in ('__batch_size', '__device'):
new_key = prefix + 'params.' + key
new_state_dict[new_key] = target_state_dict[new_key]
# check that every key in new_state_dict is in target_state_dict
for key in new_state_dict.keys():
assert key in target_state_dict, f"key {key} not in target_state_dict"
# check that all Qs keys in target_state_dict are in new_state_dict
for key in target_state_dict.keys():
if 'Qs' in key:
assert key in new_state_dict, f"key {key} not in new_state_dict"
# check that source_state_dict contains no Qs keys
for key in source_state_dict.keys():
assert 'Qs' not in key, f"key {key} contains 'Qs'"
# copy log_std_min and log_std_max from target_state_dict to new_state_dict
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
# copy new_state_dict to source_state_dict
source_state_dict.update(new_state_dict)
return source_state_dict

View File

@@ -4,6 +4,7 @@ import torch.nn.functional as F
from common import math
from common.scale import RunningScale
from common.world_model import WorldModel
from common.layers import api_model_conversion
from tensordict import TensorDict
@@ -84,39 +85,7 @@ class TDMPC2(torch.nn.Module):
"""
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
state_dict = state_dict["model"] if "model" in state_dict else state_dict
def load_sd_hook(model, local_state_dict, prefix, *args):
name_map = [
"weight", "bias", "ln.weight", "ln.bias",
]
# print("Listing state dict keys (from disk)")
# for k in list(local_state_dict.keys()):
# print("\t", k)
sd = model.state_dict()
# print("Listing dest state dict keys")
# for k in list(sd.keys()):
# print("\t", k)
# print("Maps:")
new_sd = dict(sd)
for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
for key, val in list(local_state_dict.items()):
if not key.startswith(cur_prefix[:-1]):
continue
num = key[len(cur_prefix + "params."):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = cur_prefix + 'params.' + new_key
# print("\t", key, '-->', new_total_key)
del local_state_dict[key]
new_sd[new_total_key] = val
if not cur_prefix.startswith("_target"):
new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
# print("\t", 'DETACH', key, '-->', new_total_key)
new_sd[new_total_key] = val
local_state_dict.update(new_sd)
return local_state_dict
load_sd_hook(self.model, state_dict, "_Qs.")
assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys()))
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
self.model.load_state_dict(state_dict)
return