From c95b75565535f3c3205fd5a346bffdf10c2a3117 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Wed, 9 Apr 2025 15:55:57 -0700 Subject: [PATCH] add walker2d --- tdmpc2/common/world_model.py | 7 ++++++- tdmpc2/config.yaml | 2 +- tdmpc2/envs/mujoco.py | 4 +++- tdmpc2/tdmpc2.py | 20 +++++++++++++++++--- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index 1ce6943..91e581e 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -184,10 +184,11 @@ class WorldModel(nn.Module): `return_type` can be one of [`min`, `avg`, `all`]: - `min`: return the minimum of two randomly subsampled Q-values. - `avg`: return the average of two randomly subsampled Q-values. + - 'avg-all': return the average of all Q-values. - `all`: return all Q-values. `target` specifies whether to use the target Q-networks or not. """ - assert return_type in {'min', 'avg', 'all'} + assert return_type in {'min', 'avg', 'avg-all', 'all'} if self.cfg.multitask: z = self.task_emb(z, task) @@ -204,6 +205,10 @@ class WorldModel(nn.Module): if return_type == 'all': return out + if return_type == 'avg-all': + Q = math.two_hot_inv(out, self.cfg) + return Q.mean(0) + qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] Q = math.two_hot_inv(out[qidx], self.cfg) if return_type == "min": diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index a077bb6..6d3510c 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -16,7 +16,7 @@ steps: 10_000_000 batch_size: 256 reward_coef: 0.1 value_coef: 0.1 -termination_coef: 1 +termination_coef: 20 consistency_coef: 20 rho: 0.5 lr: 3e-4 diff --git a/tdmpc2/envs/mujoco.py b/tdmpc2/envs/mujoco.py index e3a8f90..a41d4ad 100644 --- a/tdmpc2/envs/mujoco.py +++ b/tdmpc2/envs/mujoco.py @@ -4,6 +4,7 @@ from envs.wrappers.timeout import Timeout MUJOCO_TASKS = { + 'mujoco-walker': 'Walker2d-v4', 'mujoco-halfcheetah': 'HalfCheetah-v4', 'lunarlander-continuous': 
'LunarLander-v2', } @@ -46,7 +47,8 @@ def make_env(cfg): if cfg.task == 'lunarlander-continuous': env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array') else: - env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') + env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False) env = MuJoCoWrapper(env, cfg) env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000) + cfg.discount_max = 0.99 # TODO: temporarily hardcoded for these envs, makes comparison to other codebases easier return env diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index a28c7ee..8ab3c2b 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -128,7 +128,6 @@ class TDMPC2(torch.nn.Module): for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G = G + discount * (1-termination) * reward discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update @@ -255,6 +254,7 @@ class TDMPC2(torch.nn.Module): """ action, _ = self.model.pi(next_z, task) discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount + # return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True) def _update(self, obs, action, reward, terminated, task=None): @@ -314,13 +314,27 @@ class TDMPC2(torch.nn.Module): # Return training statistics self.model.eval() + # termination classification metrics + # fraction of terminations in batch + termination_rate = terminated[-1].sum() / self.cfg.batch_size + # recall = TP / (TP + FN) + termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum() + termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum() +
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum() + termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9) + # precision = TP / (TP + FP) + termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9) + # F1 score = 2 * (precision * recall) / (precision + recall) + termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9) info = TensorDict({ "consistency_loss": consistency_loss, "reward_loss": reward_loss, "value_loss": value_loss, "termination_loss": termination_loss, - "termination_mean": termination_pred.mean(), - "termination_mean_gt": terminated[-1].mean(), + "termination_rate": termination_rate, + "termination_recall": termination_recall, + "termination_precision": termination_precision, + "termination_f1": termination_f1, "total_loss": total_loss, "grad_norm": grad_norm, })