Conversion tools for state-dicts (#55)

* init

* init

* amend
This commit is contained in:
Vincent Moens
2025-01-20 23:49:36 +00:00
committed by GitHub
parent a19f91c0b5
commit ae4238946f
7 changed files with 63 additions and 31 deletions

View File

@@ -12,7 +12,7 @@ class Buffer():
def __init__(self, cfg): def __init__(self, cfg):
self.cfg = cfg self.cfg = cfg
self._device = torch.device('cuda:0') self._device = torch.get_default_device()
self._capacity = min(cfg.buffer_size, cfg.steps) self._capacity = min(cfg.buffer_size, cfg.steps)
self._sampler = SliceSampler( self._sampler = SliceSampler(
num_slices=self.cfg.batch_size, num_slices=self.cfg.batch_size,
@@ -59,7 +59,7 @@ class Buffer():
total_bytes = bytes_per_step*self._capacity total_bytes = bytes_per_step*self._capacity
print(f'Storage required: {total_bytes/1e9:.2f} GB') print(f'Storage required: {total_bytes/1e9:.2f} GB')
# Heuristic: decide whether to use CUDA or CPU memory # Heuristic: decide whether to use CUDA or CPU memory
storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu' storage_device = torch.get_default_device() if 2.5*total_bytes < mem_free else 'cpu'
print(f'Using {storage_device.upper()} memory for storage.') print(f'Using {storage_device.upper()} memory for storage.')
self._storage_device = torch.device(storage_device) self._storage_device = torch.device(storage_device)
return self._reserve_buffer( return self._reserve_buffer(

View File

@@ -7,8 +7,8 @@ class RunningScale(torch.nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda'))) self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.get_default_device()))
self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda'))) self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.get_default_device()))
def state_dict(self): def state_dict(self):
return dict(value=self.value, percentiles=self._percentiles) return dict(value=self.value, percentiles=self._percentiles)

View File

@@ -44,8 +44,11 @@ class WorldModel(nn.Module):
self._target_Qs = deepcopy(self._Qs) self._target_Qs = deepcopy(self._Qs)
# Assign params to modules # Assign params to modules
self._detach_Qs.params = self._detach_Qs_params # We do this strange assignment to avoid having duplicated tensors in the state-dict -- working on a better API for this
self._target_Qs.params = self._target_Qs_params delattr(self._detach_Qs, "params")
self._detach_Qs.__dict__["params"] = self._detach_Qs_params
delattr(self._target_Qs, "params")
self._target_Qs.__dict__["params"] = self._target_Qs_params
def __repr__(self): def __repr__(self):
repr = 'TD-MPC2 World Model\n' repr = 'TD-MPC2 World Model\n'

View File

@@ -9,8 +9,11 @@ from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task): def missing_dependencies(task):
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
try:
from envs.dmcontrol import make_env as make_dm_control_env from envs.dmcontrol import make_env as make_dm_control_env
try:
pass
except: except:
make_dm_control_env = missing_dependencies make_dm_control_env = missing_dependencies
try: try:
@@ -64,7 +67,8 @@ def make_env(cfg):
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
except ValueError: except ValueError as err:
print(err)
pass pass
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')

View File

@@ -1,5 +1,5 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
import warnings import warnings
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@@ -39,6 +39,7 @@ def evaluate(cfg: dict):
$ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true $ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
``` ```
""" """
if torch.get_default_device().type == "cuda":
assert torch.cuda.is_available() assert torch.cuda.is_available()
assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.' assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
cfg = parse_cfg(cfg) cfg = parse_cfg(cfg)

View File

@@ -1,3 +1,5 @@
import os
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -6,6 +8,7 @@ from common.scale import RunningScale
from common.world_model import WorldModel from common.world_model import WorldModel
from tensordict import TensorDict from tensordict import TensorDict
torch.set_default_device(os.getenv("TDMPC2_DEFAULT_DEVICE", "cuda:0"))
class TDMPC2(torch.nn.Module): class TDMPC2(torch.nn.Module):
""" """
@@ -17,7 +20,7 @@ class TDMPC2(torch.nn.Module):
def __init__(self, cfg): def __init__(self, cfg):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.device = torch.device('cuda:0') self.device = torch.get_default_device()
self.model = WorldModel(cfg).to(self.device) self.model = WorldModel(cfg).to(self.device)
self.optim = torch.optim.Adam([ self.optim = torch.optim.Adam([
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
@@ -32,7 +35,7 @@ class TDMPC2(torch.nn.Module):
self.scale = RunningScale(cfg) self.scale = RunningScale(cfg)
self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=torch.get_default_device()
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
@@ -82,23 +85,43 @@ class TDMPC2(torch.nn.Module):
Args: Args:
fp (str or dict): Filepath or state dict to load. fp (str or dict): Filepath or state dict to load.
""" """
state_dict = fp if isinstance(fp, dict) else torch.load(fp) state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
state_dict = state_dict["model"] if "model" in state_dict else state_dict state_dict = state_dict["model"] if "model" in state_dict else state_dict
try: # Checkpoints created AFTER Nov 10 update def load_sd_hook(model, local_state_dict, prefix, *args):
self.model.load_state_dict(state_dict) name_map = [
except: # Backwards compatibility "weight", "bias", "ln.weight", "ln.bias",
def _get_submodule(state_dict, key): ]
return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")} print("Listing state dict keys (from disk)")
for key in ["encoder", "dynamics", "reward", "pi"]: for k in list(local_state_dict.keys()):
submodule_state_dict = _get_submodule(state_dict, key) print("\t", k)
getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict)
# Q-function requires special handling sd = model.state_dict()
Qs_state_dict = _get_submodule(state_dict, "Qs") print("Listing dest state dict keys")
# TODO: figure out how to load Q-function state_dict from old checkpoints for k in list(sd.keys()):
raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \ print("\t", k)
"please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \
"until a fix is issued.") print("Maps:")
new_sd = dict(sd)
for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
for key, val in list(local_state_dict.items()):
if not key.startswith(cur_prefix[:-1]):
continue
num = key[len(cur_prefix + "params."):]
new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
new_total_key = cur_prefix + 'params.' + new_key
print("\t", key, '-->', new_total_key)
del local_state_dict[key]
new_sd[new_total_key] = val
if not cur_prefix.startswith("_target"):
new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
print("\t", 'DETACH', key, '-->', new_total_key)
new_sd[new_total_key] = val
local_state_dict.update(new_sd)
return local_state_dict
load_sd_hook(self.model, state_dict, "_Qs.")
assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys()))
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
return
@torch.no_grad() @torch.no_grad()
def act(self, obs, t0=False, eval_mode=False, task=None): def act(self, obs, t0=False, eval_mode=False, task=None):

View File

@@ -1,5 +1,5 @@
import os import os
os.environ['MUJOCO_GL'] = 'egl' os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
os.environ['LAZY_LEGACY_OP'] = '0' os.environ['LAZY_LEGACY_OP'] = '0'
os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1" os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
os.environ['TORCH_LOGS'] = "+recompiles" os.environ['TORCH_LOGS'] = "+recompiles"
@@ -43,6 +43,7 @@ def train(cfg: dict):
$ python train.py task=dog-run steps=7000000 $ python train.py task=dog-run steps=7000000
``` ```
""" """
if torch.get_default_device().type == 'cuda':
assert torch.cuda.is_available() assert torch.cuda.is_available()
assert cfg.steps > 0, 'Must train for at least 1 step.' assert cfg.steps > 0, 'Must train for at least 1 step.'
cfg = parse_cfg(cfg) cfg = parse_cfg(cfg)