auto-convert old checkpoints to new format
This commit is contained in:
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
||||
from tensordict import from_modules
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class Ensemble(nn.Module):
|
||||
"""
|
||||
Vectorized ensemble of modules.
|
||||
@@ -161,3 +162,58 @@ def enc(cfg, out={}):
|
||||
else:
|
||||
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
|
||||
return nn.ModuleDict(out)
|
||||
|
||||
|
||||
def api_model_conversion(target_state_dict, source_state_dict):
|
||||
"""
|
||||
Converts a checkpoint from our old API to the new torch.compile compatible API.
|
||||
"""
|
||||
# check whether checkpoint is already in the new format
|
||||
if "_detach_Qs_params.0.weight" in source_state_dict:
|
||||
return source_state_dict
|
||||
|
||||
name_map = ['weight', 'bias', 'ln.weight', 'ln.bias']
|
||||
new_state_dict = dict()
|
||||
|
||||
# rename keys
|
||||
for key, val in list(source_state_dict.items()):
|
||||
if key.startswith('_Qs.'):
|
||||
num = key[len('_Qs.params.'):]
|
||||
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||
new_total_key = "_Qs.params." + new_key
|
||||
del source_state_dict[key]
|
||||
new_state_dict[new_total_key] = val
|
||||
new_total_key = "_detach_Qs_params." + new_key
|
||||
new_state_dict[new_total_key] = val
|
||||
elif key.startswith('_target_Qs.'):
|
||||
num = key[len('_target_Qs.params.'):]
|
||||
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||
new_total_key = "_target_Qs_params." + new_key
|
||||
del source_state_dict[key]
|
||||
new_state_dict[new_total_key] = val
|
||||
|
||||
# add batch_size and device from target_state_dict to new_state_dict
|
||||
for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_'):
|
||||
for key in ('__batch_size', '__device'):
|
||||
new_key = prefix + 'params.' + key
|
||||
new_state_dict[new_key] = target_state_dict[new_key]
|
||||
|
||||
# check that every key in new_state_dict is in target_state_dict
|
||||
for key in new_state_dict.keys():
|
||||
assert key in target_state_dict, f"key {key} not in target_state_dict"
|
||||
# check that all Qs keys in target_state_dict are in new_state_dict
|
||||
for key in target_state_dict.keys():
|
||||
if 'Qs' in key:
|
||||
assert key in new_state_dict, f"key {key} not in new_state_dict"
|
||||
# check that source_state_dict contains no Qs keys
|
||||
for key in source_state_dict.keys():
|
||||
assert 'Qs' not in key, f"key {key} contains 'Qs'"
|
||||
|
||||
# copy log_std_min and log_std_max from target_state_dict to new_state_dict
|
||||
new_state_dict['log_std_min'] = target_state_dict['log_std_min']
|
||||
new_state_dict['log_std_dif'] = target_state_dict['log_std_dif']
|
||||
|
||||
# copy new_state_dict to source_state_dict
|
||||
source_state_dict.update(new_state_dict)
|
||||
|
||||
return source_state_dict
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
||||
from common import math
|
||||
from common.scale import RunningScale
|
||||
from common.world_model import WorldModel
|
||||
from common.layers import api_model_conversion
|
||||
from tensordict import TensorDict
|
||||
|
||||
|
||||
@@ -84,39 +85,7 @@ class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
|
||||
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||
def load_sd_hook(model, local_state_dict, prefix, *args):
|
||||
name_map = [
|
||||
"weight", "bias", "ln.weight", "ln.bias",
|
||||
]
|
||||
# print("Listing state dict keys (from disk)")
|
||||
# for k in list(local_state_dict.keys()):
|
||||
# print("\t", k)
|
||||
|
||||
sd = model.state_dict()
|
||||
# print("Listing dest state dict keys")
|
||||
# for k in list(sd.keys()):
|
||||
# print("\t", k)
|
||||
|
||||
# print("Maps:")
|
||||
new_sd = dict(sd)
|
||||
for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
|
||||
for key, val in list(local_state_dict.items()):
|
||||
if not key.startswith(cur_prefix[:-1]):
|
||||
continue
|
||||
num = key[len(cur_prefix + "params."):]
|
||||
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
||||
new_total_key = cur_prefix + 'params.' + new_key
|
||||
# print("\t", key, '-->', new_total_key)
|
||||
del local_state_dict[key]
|
||||
new_sd[new_total_key] = val
|
||||
if not cur_prefix.startswith("_target"):
|
||||
new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
|
||||
# print("\t", 'DETACH', key, '-->', new_total_key)
|
||||
new_sd[new_total_key] = val
|
||||
local_state_dict.update(new_sd)
|
||||
return local_state_dict
|
||||
load_sd_hook(self.model, state_dict, "_Qs.")
|
||||
assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys()))
|
||||
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
|
||||
self.model.load_state_dict(state_dict)
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user