added save and load for optimizers
This commit is contained in:
32
tools.py
32
tools.py
@@ -970,3 +970,35 @@ def enable_deterministic_run():
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
|
||||
def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
|
||||
if optimizers_state_dicts is None:
|
||||
optimizers_state_dicts = {}
|
||||
attrs = obj.__dict__
|
||||
if isinstance(obj, torch.nn.Module):
|
||||
attrs.update(
|
||||
{k: attr for k, attr in obj.named_modules() if "." not in k and obj != attr}
|
||||
)
|
||||
for name, attr in attrs.items():
|
||||
new_path = path + "." + name if path else name
|
||||
if isinstance(attr, torch.optim.Optimizer):
|
||||
optimizers_state_dicts[new_path] = attr.state_dict()
|
||||
elif hasattr(attr, "__dict__"):
|
||||
optimizers_state_dicts.update(
|
||||
recursively_collect_optim_state_dict(
|
||||
attr, new_path, optimizers_state_dicts
|
||||
)
|
||||
)
|
||||
return optimizers_state_dicts
|
||||
|
||||
|
||||
def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
|
||||
print(optimizers_state_dicts)
|
||||
for path, state_dict in optimizers_state_dicts.items():
|
||||
keys = path.split(".")
|
||||
obj_now = obj
|
||||
for key in keys:
|
||||
obj_now = getattr(obj_now, key)
|
||||
print(keys)
|
||||
obj_now.load_state_dict(state_dict)
|
||||
|
||||
Reference in New Issue
Block a user