From eece80123d0f5068f9aa4460b91833242a63ab54 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 15 Apr 2025 15:55:05 -0700 Subject: [PATCH] full support for episodic rl --- tdmpc2/common/logger.py | 2 +- tdmpc2/common/math.py | 18 +++++++++++++++++- tdmpc2/common/world_model.py | 23 ++++++++--------------- tdmpc2/config.yaml | 4 ++-- tdmpc2/envs/dmcontrol.py | 3 +-- tdmpc2/envs/mujoco.py | 8 ++++++-- tdmpc2/envs/wrappers/episodic.py | 24 ------------------------ tdmpc2/tdmpc2.py | 29 +++++++++-------------------- tdmpc2/train.py | 2 -- tdmpc2/trainer/offline_trainer.py | 2 +- tdmpc2/trainer/online_trainer.py | 14 ++++++++++---- 11 files changed, 55 insertions(+), 74 deletions(-) delete mode 100644 tdmpc2/envs/wrappers/episodic.py diff --git a/tdmpc2/common/logger.py b/tdmpc2/common/logger.py index 8ea2c2e..b9b0b1f 100755 --- a/tdmpc2/common/logger.py +++ b/tdmpc2/common/logger.py @@ -16,7 +16,7 @@ CONSOLE_FORMAT = [ ("step", "I", "int"), ("episode_reward", "R", "float"), ("episode_success", "S", "float"), - ("total_time", "T", "time"), + ("elapsed_time", "T", "time"), ] CAT_TO_COLOR = { diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index cc37800..57a8da8 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from tensordict import TensorDict def soft_ce(pred, target, cfg): @@ -84,11 +85,26 @@ def two_hot_inv(x, cfg): def gumbel_softmax_sample(p, temperature=1.0, dim=0): + """Sample from the Gumbel-Softmax distribution.""" logits = p.log() - # Generate Gumbel noise gumbels = ( -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() ) # ~Gumbel(0,1) gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau) y_soft = gumbels.softmax(dim) return y_soft.argmax(-1) + + +def termination_statistics(pred, target, eps=1e-9): + """Compute episode termination statistics.""" + pred = pred.squeeze(-1) + target = target.squeeze(-1) + rate = target.sum() / len(target) + tp = ((pred > 0.5) & (target == 1)).sum() + fn = ((pred <= 0.5) & (target == 1)).sum() + fp = ((pred > 0.5) & (target == 0)).sum() + recall = tp / (tp + fn + eps) + precision = tp / (tp + fp + eps) + f1 = 2 * (precision * recall) / (precision + recall + eps) + return TensorDict({'termination_rate': rate, + 'termination_f1': f1}) diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index d0040b8..1d7bd73 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -56,6 +56,8 @@ class WorldModel(nn.Module): repr = 'TD-MPC2 World Model\n' modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions'] for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]): + if m == self._termination and not self.cfg.episodic: + continue repr += f"{modules[i]}: {m}\n" repr += "Learnable parameters: {:,}".format(self.total_params) return repr @@ -127,16 +129,17 @@ class WorldModel(nn.Module): z = torch.cat([z, a], dim=-1) return self._reward(z) - def termination(self, z, task, sigmoid=True): + def termination(self, z, task, unnormalized=False): """ Predicts termination signal. """ assert task is None if self.cfg.multitask: z = self.task_emb(z, task) - if sigmoid: - return torch.sigmoid(self._termination(z)) - return self._termination(z) + if unnormalized: + return self._termination(z) + return torch.sigmoid(self._termination(z)) + def pi(self, z, task): """ @@ -186,12 +189,10 @@ class WorldModel(nn.Module): `return_type` can be one of [`min`, `avg`, `all`]: - `min`: return the minimum of two randomly subsampled Q-values. - `avg`: return the average of two randomly subsampled Q-values. - - 'min-all': return the minimum of all Q-values. - - 'avg-all': return the average of all Q-values. - `all`: return all Q-values. `target` specifies whether to use the target Q-networks or not. """ - assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'} + assert return_type in {'min', 'avg', 'all'} if self.cfg.multitask: z = self.task_emb(z, task) @@ -208,14 +209,6 @@ class WorldModel(nn.Module): if return_type == 'all': return out - if return_type == 'avg-all': - Q = math.two_hot_inv(out, self.cfg) - return Q.mean(0) - - if return_type == 'min-all': - Q = math.two_hot_inv(out, self.cfg) - return Q.min(0).values - qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] Q = math.two_hot_inv(out[qidx], self.cfg) if return_type == "min": diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index a077bb6..ac6ca5d 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -2,9 +2,9 @@ defaults: - override hydra/launcher: submitit_local # environment -task: cartpole-balance-sparse +task: dog-run obs: state -episodic: true +episodic: false # evaluation checkpoint: ??? diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index abfd7a7..a6e21b3 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -9,9 +9,8 @@ from dm_control import suite suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) from dm_control.suite.wrappers import action_scale -from envs.wrappers.timeout import Timeout -from envs.wrappers.episodic import EpisodicWrapper +from envs.wrappers.timeout import Timeout def get_obs_shape(env): diff --git a/tdmpc2/envs/mujoco.py b/tdmpc2/envs/mujoco.py index 6c5946d..358775c 100644 --- a/tdmpc2/envs/mujoco.py +++ b/tdmpc2/envs/mujoco.py @@ -6,6 +6,7 @@ from envs.wrappers.timeout import Timeout MUJOCO_TASKS = { 'mujoco-walker': 'Walker2d-v4', 'mujoco-halfcheetah': 'HalfCheetah-v4', + 'bipedal-walker': 'BipedalWalker-v3', 'lunarlander-continuous': 'LunarLander-v2', } @@ -49,7 +50,10 @@ def make_env(cfg): else: env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') env = MuJoCoWrapper(env, cfg) - env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000) + env = Timeout(env, max_episode_steps={ + 'lunarlander-continuous': 500, + 'bipedal-walker': 1600, + }.get(cfg.task, 1000)) # Default max episode steps for other tasks cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier - cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks + cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence return env diff --git a/tdmpc2/envs/wrappers/episodic.py b/tdmpc2/envs/wrappers/episodic.py deleted file mode 100644 index b0d4794..0000000 --- a/tdmpc2/envs/wrappers/episodic.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections import deque - -import gymnasium as gym -import numpy as np -import torch - - -class EpisodicWrapper(gym.Wrapper): - """ - Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment. - """ - - def __init__(self, cfg, env): - super().__init__(env) - assert cfg.task == 'cartpole-balance-sparse' - self.cfg = cfg - self.env = env - - def step(self, action): - obs, reward, done, info = self.env.step(action) - if self.cfg.episodic and reward == 0: - done = True - info['terminated'] = True - return obs, reward, done, info diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 80a72cc..1df1b77 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module): self.discount = torch.tensor( [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) - print('Max episode length:', cfg.episode_length) + print('Episode length:', cfg.episode_length) print('Discount factor:', self.discount) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: @@ -197,7 +197,7 @@ class TDMPC2(torch.nn.Module): std = std * self.model._action_masks[task] # Select action - rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs + rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) a, std = actions[0], std[0] if not eval_mode: @@ -279,7 +279,7 @@ class TDMPC2(torch.nn.Module): _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - termination_pred = self.model.termination(zs[1:], task, sigmoid=False) + termination_pred = self.model.termination(zs[1:], task, unnormalized=True) # Compute losses reward_loss, value_loss = 0, 0 @@ -290,7 +290,10 @@ class TDMPC2(torch.nn.Module): consistency_loss = consistency_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon - termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated) + if self.cfg.episodic: + termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated) + else: + termination_loss = 0. value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss + @@ -313,30 +316,16 @@ class TDMPC2(torch.nn.Module): # Return training statistics self.model.eval() - # termination classification metrics - # number of terminations in batch - termination_rate = terminated[-1].sum() / self.cfg.batch_size - # recall = TP / (TP + FN) - termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum() - termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum() - termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum() - termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9) - # precision = TP / (TP + FP) - termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9) - # F1 score = 2 * (precision * recall) / (precision + recall) - termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9) info = TensorDict({ "consistency_loss": consistency_loss, "reward_loss": reward_loss, "value_loss": value_loss, "termination_loss": termination_loss, - "termination_rate": termination_rate, - "termination_recall": termination_recall, - "termination_precision": termination_precision, - "termination_f1": termination_f1, "total_loss": total_loss, "grad_norm": grad_norm, }) + if self.cfg.episodic: + info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1])) info.update(pi_info) return info.detach().mean() diff --git a/tdmpc2/train.py b/tdmpc2/train.py index b5040c1..5676349 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -48,8 +48,6 @@ def train(cfg: dict): cfg = parse_cfg(cfg) set_seed(cfg.seed) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) - assert cfg.episodic, \ - f'This branch is experimental and only supports episodic RL tasks at this time.' trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = trainer_cls( diff --git a/tdmpc2/trainer/offline_trainer.py b/tdmpc2/trainer/offline_trainer.py index a46d00b..a64761b 100755 --- a/tdmpc2/trainer/offline_trainer.py +++ b/tdmpc2/trainer/offline_trainer.py @@ -81,7 +81,7 @@ class OfflineTrainer(Trainer): if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: metrics = { 'iteration': i, - 'total_time': time() - self._start_time, + 'elapsed_time': time() - self._start_time, } metrics.update(train_metrics) if i % self.cfg.eval_freq == 0: diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 17f4359..83128f7 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -17,15 +17,17 @@ class OnlineTrainer(Trainer): def common_metrics(self): """Return a dictionary of current metrics.""" + elapsed_time = time() - self._start_time return dict( step=self._step, episode=self._ep_idx, - total_time=time() - self._start_time, + elapsed_time=elapsed_time, + steps_per_second=self._step / elapsed_time ) def eval(self): """Evaluate a TD-MPC2 agent.""" - ep_rewards, ep_successes = [], [] + ep_rewards, ep_successes, ep_lengths = [], [], [] for i in range(self.cfg.eval_episodes): obs, done, ep_reward, t = self.env.reset(), False, 0, 0 if self.cfg.save_video: @@ -40,11 +42,13 @@ class OnlineTrainer(Trainer): self.logger.video.record(self.env) ep_rewards.append(ep_reward) ep_successes.append(info['success']) + ep_lengths.append(t) if self.cfg.save_video: self.logger.video.save(self._step) return dict( episode_reward=np.nanmean(ep_rewards), episode_success=np.nanmean(ep_successes), + episode_length= np.nanmean(ep_lengths), ) def to_td(self, obs, action=None, reward=None, terminated=None): @@ -84,12 +88,14 @@ class OnlineTrainer(Trainer): eval_next = False if self._step > 0: + if info['terminated'] and not self.cfg.episodic: + raise ValueError('Termination detected but you are not in episodic mode. ' \ + 'Set `episodic=true` to enable support for terminations.') train_metrics.update( episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_success=info['success'], episode_length=len(self._tds), - episode_terminated=info['terminated'], - ) + episode_terminated=info['terminated']) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') self._ep_idx = self.buffer.add(torch.cat(self._tds))