From 81eb17068e8dbb3e0f7f19c747901194260d4b64 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 8 Apr 2025 19:15:31 -0700 Subject: [PATCH] QoL improvements to termination signal debugging --- docker/environment.yaml | 6 +++- tdmpc2/common/world_model.py | 10 +++--- tdmpc2/config.yaml | 4 +-- tdmpc2/envs/__init__.py | 14 ++++++--- tdmpc2/envs/mujoco.py | 52 ++++++++++++++++++++++++++++++++ tdmpc2/tdmpc2.py | 21 ++++++++----- tdmpc2/train.py | 5 ++- tdmpc2/trainer/online_trainer.py | 2 ++ 8 files changed, 90 insertions(+), 24 deletions(-) create mode 100644 tdmpc2/envs/mujoco.py diff --git a/docker/environment.yaml b/docker/environment.yaml index 45eeec7..ecc342b 100644 --- a/docker/environment.yaml +++ b/docker/environment.yaml @@ -1,4 +1,4 @@ -name: episodic +name: tdmpc2 channels: - pytorch-nightly - nvidia @@ -55,3 +55,7 @@ dependencies: # MyoSuite: # - myosuite #################### + # Classic MuJoCo/Box2d: + # - swig + # - gymnasium[box2d] + #################### diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 34ef614..1ce6943 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -25,7 +25,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) + self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) @@ -54,8 +54,8 @@ class WorldModel(nn.Module): def __repr__(self): repr = 'TD-MPC2 World Model\n' - modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions'] - for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]): + modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions'] + for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]): repr += f"{modules[i]}: {m}\n" repr += "Learnable parameters: {:,}".format(self.total_params) return repr @@ -127,14 +127,14 @@ class WorldModel(nn.Module): z = torch.cat([z, a], dim=-1) return self._reward(z) - def terminated(self, z, task): + def termination(self, z, task): """ Predicts termination signal. """ assert task is None if self.cfg.multitask: z = self.task_emb(z, task) - return torch.sigmoid(self._terminated(z)) + return torch.sigmoid(self._termination(z)) def pi(self, z, task): """ diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 5f62907..a077bb6 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -16,7 +16,7 @@ steps: 10_000_000 batch_size: 256 reward_coef: 0.1 value_coef: 0.1 -terminated_coef: 0.1 +termination_coef: 1 consistency_coef: 20 rho: 0.5 lr: 3e-4 @@ -90,4 +90,4 @@ seed_steps: ??? bin_size: ??? # speedups -compile: False +compile: false diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 7b1e0a3..46f99a8 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -9,10 +9,10 @@ from envs.wrappers.tensor import TensorWrapper def missing_dependencies(task): raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.') -# try: -from envs.dmcontrol import make_env as make_dm_control_env -# except: - # make_dm_control_env = missing_dependencies +try: + from envs.dmcontrol import make_env as make_dm_control_env +except: + make_dm_control_env = missing_dependencies try: from envs.maniskill import make_env as make_maniskill_env except: @@ -25,6 +25,10 @@ try: from envs.myosuite import make_env as make_myosuite_env except: make_myosuite_env = missing_dependencies +try: + from envs.mujoco import make_env as make_mujoco_env +except: + make_mujoco_env = missing_dependencies warnings.filterwarnings('ignore', category=DeprecationWarning) @@ -61,7 +65,7 @@ def make_env(cfg): else: env = None - for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: + for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]: try: env = fn(cfg) except ValueError: diff --git a/tdmpc2/envs/mujoco.py b/tdmpc2/envs/mujoco.py new file mode 100644 index 0000000..e3a8f90 --- /dev/null +++ b/tdmpc2/envs/mujoco.py @@ -0,0 +1,52 @@ +import numpy as np +import gymnasium as gym +from envs.wrappers.timeout import Timeout + + +MUJOCO_TASKS = { + 'mujoco-halfcheetah': 'HalfCheetah-v4', + 'lunarlander-continuous': 'LunarLander-v2', +} + +class MuJoCoWrapper(gym.Wrapper): + def __init__(self, env, cfg): + super().__init__(env) + self.env = env + self.cfg = cfg + self._cumulative_reward = 0 + + def reset(self): + self._cumulative_reward = 0 + return self.env.reset()[0] + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action.copy()) + self._cumulative_reward += reward + done = terminated or truncated + info['terminated'] = terminated + if self.cfg.task == 'lunarlander-continuous': + info['success'] = self._cumulative_reward > 200 + return obs, reward, done, info + + @property + def unwrapped(self): + return self.env.unwrapped + + def render(self, **kwargs): + return self.env.render(**kwargs) + + +def make_env(cfg): + """ + Make classic/MuJoCo environment. + """ + if not cfg.task in MUJOCO_TASKS: + raise ValueError('Unknown task:', cfg.task) + assert cfg.obs == 'state', 'This task only supports state observations.' + if cfg.task == 'lunarlander-continuous': + env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array') + else: + env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') + env = MuJoCoWrapper(env, cfg) + env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000) + return env diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 6f44e15..a28c7ee 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module): {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, - {'params': self.model._terminated.parameters()}, + {'params': self.model._termination.parameters()}, {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else [] } @@ -36,6 +36,8 @@ class TDMPC2(torch.nn.Module): self.discount = torch.tensor( [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) + print('Episode length:', cfg.episode_length) + print('Discount factor:', self.discount) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: print('Compiling update function with torch.compile...') @@ -122,17 +124,17 @@ class TDMPC2(torch.nn.Module): def _estimate_value(self, z, actions, task): """Estimate value of a trajectory starting at latent state z and executing given actions.""" G, discount = 0, 1 - terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device) + termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device) for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G = G + discount * (1-terminated) * reward + G = G + discount * (1-termination) * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update - terminated = torch.clip(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.) + termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.) action, _ = self.model.pi(z, task) - return G + discount * (1-terminated) * self.model.Q(z, action, task, return_type='avg') + return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg') @torch.no_grad() def _plan(self, obs, t0=False, eval_mode=False, task=None): @@ -278,7 +280,7 @@ class TDMPC2(torch.nn.Module): _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - terminated_pred = self.model.terminated(zs[-1], task) + termination_pred = self.model.termination(zs[-1], task) # Compute losses reward_loss, value_loss = 0, 0 @@ -289,12 +291,12 @@ class TDMPC2(torch.nn.Module): consistency_loss = consistency_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon - terminated_loss = F.binary_cross_entropy(terminated_pred, terminated[-1]) + termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1]) value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss + self.cfg.reward_coef * reward_loss + - self.cfg.terminated_coef * terminated_loss + + self.cfg.termination_coef * termination_loss + self.cfg.value_coef * value_loss ) @@ -316,6 +318,9 @@ class TDMPC2(torch.nn.Module): "consistency_loss": consistency_loss, "reward_loss": reward_loss, "value_loss": value_loss, + "termination_loss": termination_loss, + "termination_mean": termination_pred.mean(), + "termination_mean_gt": terminated[-1].mean(), "total_loss": total_loss, "grad_norm": grad_norm, }) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 09813fa..b5040c1 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -48,9 +48,8 @@ def train(cfg: dict): cfg = parse_cfg(cfg) set_seed(cfg.seed) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) - - assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \ - f'This branch is experimental and only supports cartpole-balance-sparse at this time.' + assert cfg.episodic, \ + f'This branch is experimental and only supports episodic RL tasks at this time.' trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = trainer_cls( diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 84991e7..17f4359 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -87,6 +87,8 @@ class OnlineTrainer(Trainer): train_metrics.update( episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_success=info['success'], + episode_length=len(self._tds), + episode_terminated=info['terminated'], ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train')