From ae4238946f65d8647a32659f046c3b0136f75dc5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 20 Jan 2025 23:49:36 +0000
Subject: [PATCH] Conversion tools for state-dicts (#55)

* init

* init

* amend
---
Review notes (free-form text in this area is not part of the commit message):

* Layout: this copy was rebuilt from a whitespace-collapsed paste. Line
  breaks follow the diff syntax; blank context lines were placed using the
  hunk counts; indentation (4 spaces per level) is reconstructed and may not
  match the tree byte-for-byte. Check with `git apply --check` and, if
  needed, `--ignore-whitespace` before applying.
* Fix applied (tdmpc2/common/buffer.py, second hunk): the added line stored
  the torch.device returned by torch.get_default_device() in
  `storage_device`, but the following unchanged line calls
  `storage_device.upper()`, which a torch.device does not provide. The value
  is now wrapped in str(), so both branches of the conditional yield a
  string and the later torch.device(storage_device) call still works. The
  post-image blob id on that file's "index" line predates this fix.
* NOTE(review), tdmpc2/envs/__init__.py: the dm_control import now runs
  outside the try block and the try body is only `pass`, so the
  `make_dm_control_env = missing_dependencies` fallback looks unreachable.
  Confirm this is intended and not a debugging leftover.
* NOTE(review), tdmpc2/tdmpc2.py load(): load_sd_hook now runs on every
  load and calls int() on the key suffix after "<prefix>params.". A key
  that is already in the "<idx>.<name>" form it produces would presumably
  raise ValueError; confirm that checkpoints saved after this change still
  load. The hook also prints every key unconditionally.

 tdmpc2/common/buffer.py      |  4 +--
 tdmpc2/common/scale.py       |  4 +--
 tdmpc2/common/world_model.py |  7 +++--
 tdmpc2/envs/__init__.py      |  8 +++--
 tdmpc2/evaluate.py           |  9 +++---
 tdmpc2/tdmpc2.py             | 57 +++++++++++++++++++++++++-----------
 tdmpc2/train.py              |  5 ++--
 7 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py
index c23b5f8..cff2134 100644
--- a/tdmpc2/common/buffer.py
+++ b/tdmpc2/common/buffer.py
@@ -12,7 +12,7 @@ class Buffer():
 
     def __init__(self, cfg):
         self.cfg = cfg
-        self._device = torch.device('cuda:0')
+        self._device = torch.get_default_device()
         self._capacity = min(cfg.buffer_size, cfg.steps)
         self._sampler = SliceSampler(
             num_slices=self.cfg.batch_size,
@@ -59,7 +59,7 @@ class Buffer():
         total_bytes = bytes_per_step*self._capacity
         print(f'Storage required: {total_bytes/1e9:.2f} GB')
         # Heuristic: decide whether to use CUDA or CPU memory
-        storage_device = 'cuda:0' if 2.5*total_bytes < mem_free else 'cpu'
+        storage_device = str(torch.get_default_device()) if 2.5*total_bytes < mem_free else 'cpu'
         print(f'Using {storage_device.upper()} memory for storage.')
         self._storage_device = torch.device(storage_device)
         return self._reserve_buffer(
diff --git a/tdmpc2/common/scale.py b/tdmpc2/common/scale.py
index 8fd1740..0744201 100644
--- a/tdmpc2/common/scale.py
+++ b/tdmpc2/common/scale.py
@@ -7,8 +7,8 @@ class RunningScale(torch.nn.Module):
     def __init__(self, cfg):
         super().__init__()
         self.cfg = cfg
-        self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.device('cuda')))
-        self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.device('cuda')))
+        self.value = Buffer(torch.ones(1, dtype=torch.float32, device=torch.get_default_device()))
+        self._percentiles = Buffer(torch.tensor([5, 95], dtype=torch.float32, device=torch.get_default_device()))
 
     def state_dict(self):
         return dict(value=self.value, percentiles=self._percentiles)
diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py
index 4babde7..44d7951 100644
--- a/tdmpc2/common/world_model.py
+++ b/tdmpc2/common/world_model.py
@@ -44,8 +44,11 @@ class WorldModel(nn.Module):
         self._target_Qs = deepcopy(self._Qs)
 
         # Assign params to modules
-        self._detach_Qs.params = self._detach_Qs_params
-        self._target_Qs.params = self._target_Qs_params
+        # We do this strange assignment to avoid having duplicated tensors in the state-dict -- working on a better API for this
+        delattr(self._detach_Qs, "params")
+        self._detach_Qs.__dict__["params"] = self._detach_Qs_params
+        delattr(self._target_Qs, "params")
+        self._target_Qs.__dict__["params"] = self._target_Qs_params
 
     def __repr__(self):
         repr = 'TD-MPC2 World Model\n'
diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py
index 247697f..61e8088 100644
--- a/tdmpc2/envs/__init__.py
+++ b/tdmpc2/envs/__init__.py
@@ -9,8 +9,11 @@ from envs.wrappers.tensor import TensorWrapper
 
 def missing_dependencies(task):
     raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
+
+from envs.dmcontrol import make_env as make_dm_control_env
+
 try:
-    from envs.dmcontrol import make_env as make_dm_control_env
+    pass
 except:
     make_dm_control_env = missing_dependencies
 try:
@@ -64,7 +67,8 @@ def make_env(cfg):
         for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
             try:
                 env = fn(cfg)
-            except ValueError:
+            except ValueError as err:
+                print(err)
                 pass
         if env is None:
             raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
diff --git a/tdmpc2/evaluate.py b/tdmpc2/evaluate.py
index a9f04ea..da87fbb 100755
--- a/tdmpc2/evaluate.py
+++ b/tdmpc2/evaluate.py
@@ -1,5 +1,5 @@
 import os
-os.environ['MUJOCO_GL'] = 'egl'
+os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
 import warnings
 warnings.filterwarnings('ignore')
 
@@ -29,7 +29,7 @@ def evaluate(cfg: dict):
         `eval_episodes`: number of episodes to evaluate on per task (default: 10)
         `save_video`: whether to save a video of the evaluation (default: True)
         `seed`: random seed (default: 1)
-    
+
     See config.yaml for a full list of args.
 
     Example usage:
@@ -39,7 +39,8 @@ def evaluate(cfg: dict):
         $ python evaluate.py task=dog-run checkpoint=/path/to/dog-1.pt save_video=true
     ```
     """
-    assert torch.cuda.is_available()
+    if torch.get_default_device().type == "cuda":
+        assert torch.cuda.is_available()
     assert cfg.eval_episodes > 0, 'Must evaluate at least 1 episode.'
     cfg = parse_cfg(cfg)
     set_seed(cfg.seed)
@@ -57,7 +58,7 @@ def evaluate(cfg: dict):
     agent = TDMPC2(cfg)
     assert os.path.exists(cfg.checkpoint), f'Checkpoint {cfg.checkpoint} not found! Must be a valid filepath.'
     agent.load(cfg.checkpoint)
-    
+
     # Evaluate
     if cfg.multitask:
         print(colored(f'Evaluating agent on {len(cfg.tasks)} tasks:', 'yellow', attrs=['bold']))
diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py
index ef971c5..280598e 100755
--- a/tdmpc2/tdmpc2.py
+++ b/tdmpc2/tdmpc2.py
@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.nn.functional as F
 
@@ -6,6 +8,7 @@ from common.scale import RunningScale
 from common.world_model import WorldModel
 from tensordict import TensorDict
 
+torch.set_default_device(os.getenv("TDMPC2_DEFAULT_DEVICE", "cuda:0"))
 
 class TDMPC2(torch.nn.Module):
     """
@@ -17,7 +20,7 @@ class TDMPC2(torch.nn.Module):
     def __init__(self, cfg):
         super().__init__()
         self.cfg = cfg
-        self.device = torch.device('cuda:0')
+        self.device = torch.get_default_device()
         self.model = WorldModel(cfg).to(self.device)
         self.optim = torch.optim.Adam([
             {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
@@ -32,7 +35,7 @@ class TDMPC2(torch.nn.Module):
         self.scale = RunningScale(cfg)
         self.cfg.iterations += 2*int(cfg.action_dim >= 20) # Heuristic for large action spaces
         self.discount = torch.tensor(
-            [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
+            [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device=torch.get_default_device()
         ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
         self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
         if cfg.compile:
@@ -82,23 +85,43 @@ class TDMPC2(torch.nn.Module):
         Args:
             fp (str or dict): Filepath or state dict to load.
         """
-        state_dict = fp if isinstance(fp, dict) else torch.load(fp)
+        state_dict = fp if isinstance(fp, dict) else torch.load(fp, map_location=torch.get_default_device())
         state_dict = state_dict["model"] if "model" in state_dict else state_dict
-        try: # Checkpoints created AFTER Nov 10 update
-            self.model.load_state_dict(state_dict)
-        except: # Backwards compatibility
-            def _get_submodule(state_dict, key):
-                return {k.replace(f"_{key}.", ""): v for k, v in state_dict.items() if k.startswith(f"_{key}.")}
-            for key in ["encoder", "dynamics", "reward", "pi"]:
-                submodule_state_dict = _get_submodule(state_dict, key)
-                getattr(self.model, f"_{key}").load_state_dict(submodule_state_dict)
-            # Q-function requires special handling
-            Qs_state_dict = _get_submodule(state_dict, "Qs")
-            # TODO: figure out how to load Q-function state_dict from old checkpoints
-            raise NotImplementedError("Backwards compatibility is currently broken for loading of old checkpoints, " \
-                "please revert to a previous checkpoint, e.g. 88095e7899497cf7a1da36fb6bbb6bc7b5370d53 " \
-                "until a fix is issued.")
+        def load_sd_hook(model, local_state_dict, prefix, *args):
+            name_map = [
+                "weight", "bias", "ln.weight", "ln.bias",
+            ]
+            print("Listing state dict keys (from disk)")
+            for k in list(local_state_dict.keys()):
+                print("\t", k)
+
+            sd = model.state_dict()
+            print("Listing dest state dict keys")
+            for k in list(sd.keys()):
+                print("\t", k)
+
+            print("Maps:")
+            new_sd = dict(sd)
+            for cur_prefix in (prefix, "_target"+prefix[:-1]+"_"):
+                for key, val in list(local_state_dict.items()):
+                    if not key.startswith(cur_prefix[:-1]):
+                        continue
+                    num = key[len(cur_prefix + "params."):]
+                    new_key = str(int(num) // 4) + "." + name_map[int(num) % 4]
+                    new_total_key = cur_prefix + 'params.' + new_key
+                    print("\t", key, '-->', new_total_key)
+                    del local_state_dict[key]
+                    new_sd[new_total_key] = val
+                    if not cur_prefix.startswith("_target"):
+                        new_total_key = "_detach" + cur_prefix[:-1] + "_" + 'params.' + new_key
+                        print("\t", 'DETACH', key, '-->', new_total_key)
+                        new_sd[new_total_key] = val
+            local_state_dict.update(new_sd)
+            return local_state_dict
+        load_sd_hook(self.model, state_dict, "_Qs.")
+        assert not set(TensorDict(self.model.state_dict()).keys()).symmetric_difference(set(TensorDict(state_dict).keys()))
         self.model.load_state_dict(state_dict)
+        return
 
     @torch.no_grad()
     def act(self, obs, t0=False, eval_mode=False, task=None):
diff --git a/tdmpc2/train.py b/tdmpc2/train.py
index 1846145..2bd1be1 100755
--- a/tdmpc2/train.py
+++ b/tdmpc2/train.py
@@ -1,5 +1,5 @@
 import os
-os.environ['MUJOCO_GL'] = 'egl'
+os.environ['MUJOCO_GL'] = os.getenv("MUJOCO_GL", 'egl')
 os.environ['LAZY_LEGACY_OP'] = '0'
 os.environ['TORCHDYNAMO_INLINE_INBUILT_NN_MODULES'] = "1"
 os.environ['TORCH_LOGS'] = "+recompiles"
@@ -43,7 +43,8 @@ def train(cfg: dict):
         $ python train.py task=dog-run steps=7000000
     ```
     """
-    assert torch.cuda.is_available()
+    if torch.get_default_device().type == 'cuda':
+        assert torch.cuda.is_available()
     assert cfg.steps > 0, 'Must train for at least 1 step.'
     cfg = parse_cfg(cfg)
     set_seed(cfg.seed)