auto-convert old checkpoints to new format
This commit is contained in:
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||||||
from tensordict import from_modules
|
from tensordict import from_modules
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.Module):
|
class Ensemble(nn.Module):
|
||||||
"""
|
"""
|
||||||
Vectorized ensemble of modules.
|
Vectorized ensemble of modules.
|
||||||
@@ -161,3 +162,58 @@ def enc(cfg, out={}):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
|
raise NotImplementedError(f"Encoder for observation type {k} not implemented.")
|
||||||
return nn.ModuleDict(out)
|
return nn.ModuleDict(out)
|
||||||
|
|
||||||
|
|
||||||
|
def api_model_conversion(target_state_dict, source_state_dict):
    """
    Upgrade a checkpoint from the old API to the new torch.compile compatible API.

    Old checkpoints store the Q-function parameters under flat integer
    indices (``_Qs.params.<n>`` and ``_target_Qs.params.<n>``). Each such
    entry is renamed to ``<n // 4>.<field>``, where ``field`` is ``weight``,
    ``bias``, ``ln.weight`` or ``ln.bias`` for ``n % 4`` equal to 0, 1, 2, 3.
    Every ``_Qs`` entry is additionally stored under ``_detach_Qs_params.``
    with the very same value object.

    Args:
        target_state_dict: State dict of the model the checkpoint is loaded
            into. Supplies the ``__batch_size`` / ``__device`` entries of the
            three parameter containers as well as ``log_std_min`` and
            ``log_std_dif``. It is only read, never modified.
        source_state_dict: State dict read from the checkpoint. It is
            modified in place.

    Returns:
        ``source_state_dict`` itself. It is returned untouched when it
        already contains ``_detach_Qs_params.0.weight`` (new format).

    Raises:
        AssertionError: If the converted keys and the ``Qs`` keys of
            ``target_state_dict`` do not match, or if a key containing
            ``Qs`` is left over in ``source_state_dict``.
        KeyError: If ``target_state_dict`` lacks one of the entries copied
            from it.
        ValueError: If the index part of an old key is not an integer.
    """
    # The detached Q parameters only exist in the new format: nothing to do.
    if "_detach_Qs_params.0.weight" in source_state_dict:
        return source_state_dict

    fields = ('weight', 'bias', 'ln.weight', 'ln.bias')
    old_online_prefix = '_Qs.params.'
    old_target_prefix = '_target_Qs.params.'

    def structured_name(flat_index):
        # Flat index n -> "<n // 4>.<field name for n % 4>".
        group, field = divmod(int(flat_index), 4)
        return f"{group}.{fields[field]}"

    converted = {}

    # Iterate over a snapshot because entries are removed while renaming.
    for old_key, value in tuple(source_state_dict.items()):
        if old_key.startswith('_Qs.'):
            suffix = structured_name(old_key[len(old_online_prefix):])
            del source_state_dict[old_key]
            # The online and the detached container share the same value.
            converted["_Qs.params." + suffix] = value
            converted["_detach_Qs_params." + suffix] = value
        elif old_key.startswith('_target_Qs.'):
            suffix = structured_name(old_key[len(old_target_prefix):])
            del source_state_dict[old_key]
            converted["_target_Qs_params." + suffix] = value

    # Batch size and device entries are taken over from the target model.
    meta_keys = [
        prefix + 'params.' + meta
        for prefix in ('_Qs.', '_detach_Qs_', '_target_Qs_')
        for meta in ('__batch_size', '__device')
    ]
    for meta_key in meta_keys:
        converted[meta_key] = target_state_dict[meta_key]

    # Sanity checks: every converted key is known to the target model ...
    for key in converted:
        assert key in target_state_dict, f"key {key} not in target_state_dict"
    # ... every Qs key of the target model has been produced ...
    for key in target_state_dict:
        if 'Qs' not in key:
            continue
        assert key in converted, f"key {key} not in new_state_dict"
    # ... and no Qs key is left behind in the checkpoint.
    for key in source_state_dict:
        assert 'Qs' not in key, f"key {key} contains 'Qs'"

    # log_std_min and log_std_dif are taken over from the target model.
    for name in ('log_std_min', 'log_std_dif'):
        converted[name] = target_state_dict[name]

    # Merge the converted entries back into the checkpoint dict.
    source_state_dict.update(converted)
    return source_state_dict
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import torch.nn.functional as F
|
|||||||
from common import math
|
from common import math
|
||||||
from common.scale import RunningScale
|
from common.scale import RunningScale
|
||||||
from common.world_model import WorldModel
|
from common.world_model import WorldModel
|
||||||
|
from common.layers import api_model_conversion
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
|
||||||
@@ -84,39 +85,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
|
state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
|
||||||
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
state_dict = state_dict["model"] if "model" in state_dict else state_dict
|
||||||
def load_sd_hook(model, local_state_dict, prefix, *args):
|
state_dict = api_model_conversion(self.model.state_dict(), state_dict)
|
||||||
name_map = [
|
|
||||||
"weight", "bias", "ln.weight", "ln.bias",
|
|
||||||
]
|
|
||||||
# print("Listing state dict keys (from disk)")
|
|
||||||
# for k in list(local_state_dict.keys()):
|
|
||||||
# print("\t", k)
|
|
||||||
|
|
||||||
sd = model.state_dict()
|
|
||||||
# print("Listing dest state dict keys")
|
|
||||||
# for k in list(sd.keys()):
|
|
||||||
# print("\t", k)
|
|
||||||
|
|
||||||
# print("Maps:")
|
|
||||||
new_sd = dict(sd)
|
|
||||||
for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
|
|
||||||
for key, val in list(local_state_dict.items()):
|
|
||||||
if not key.startswith(cur_prefix[:-1]):
|
|
||||||
continue
|
|
||||||
num = key[len(cur_prefix + "params."):]
|
|
||||||
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
|
|
||||||
new_total_key = cur_prefix + 'params.' + new_key
|
|
||||||
# print("\t", key, '-->', new_total_key)
|
|
||||||
del local_state_dict[key]
|
|
||||||
new_sd[new_total_key] = val
|
|
||||||
if not cur_prefix.startswith("_target"):
|
|
||||||
new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
|
|
||||||
# print("\t", 'DETACH', key, '-->', new_total_key)
|
|
||||||
new_sd[new_total_key] = val
|
|
||||||
local_state_dict.update(new_sd)
|
|
||||||
return local_state_dict
|
|
||||||
load_sd_hook(self.model, state_dict, "_Qs.")
|
|
||||||
assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys()))
|
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user