From 26c72119cd7d6b513e9a43e09578477cf23cf187 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 18:16:33 -0800 Subject: [PATCH 1/2] init --- tdmpc2/common/buffer.py | 3 ++- tdmpc2/common/world_model.py | 10 ++++++++++ tdmpc2/config.yaml | 1 + tdmpc2/envs/maniskill.py | 7 +++++-- tdmpc2/envs/wrappers/tensor.py | 1 + tdmpc2/tdmpc2.py | 23 ++++++++++++++++------- tdmpc2/trainer/online_trainer.py | 12 +++++++++--- 7 files changed, 44 insertions(+), 13 deletions(-) diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index 29cc293..a060a3d 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -78,8 +78,9 @@ class Buffer(): obs = td['obs'] action = td['action'][1:] reward = td['reward'][1:].unsqueeze(-1) + terminated = td['terminated'][1:].unsqueeze(-1) task = td['task'][0] if 'task' in td.keys() else None - return self._to_device(obs, action, reward, task) + return self._to_device(obs, action, reward, terminated, task) def add(self, td): """Add an episode to the buffer.""" diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index a780ad0..4bd4a48 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -24,6 +24,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) + self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) @@ -118,6 +119,15 @@ class WorldModel(nn.Module): z = self.task_emb(z, task) z = torch.cat([z, a], dim=-1) return self._reward(z) + + def terminated(self, z, task): + """ + Predicts termination signal. + """ + assert task is None + if self.cfg.multitask: + z = self.task_emb(z, task) + return torch.sigmoid(self._terminated(z)) def pi(self, z, task): """ diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index b720923..cd3ac11 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -15,6 +15,7 @@ steps: 10_000_000 batch_size: 256 reward_coef: 0.1 value_coef: 0.1 +terminated_coef: 0.1 consistency_coef: 20 rho: 0.5 lr: 3e-4 diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 7b0b6ed..4a37b6d 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper): def step(self, action): reward = 0 for _ in range(2): - obs, r, _, info = self.env.step(action) + obs, r, done, info = self.env.step(action) reward += r - return obs, reward, False, info + info['terminated'] = done + if done: + break + return obs, reward, done, info @property def unwrapped(self): diff --git a/tdmpc2/envs/wrappers/tensor.py b/tdmpc2/envs/wrappers/tensor.py index 548a5f4..81bed19 100644 --- a/tdmpc2/envs/wrappers/tensor.py +++ b/tdmpc2/envs/wrappers/tensor.py @@ -37,4 +37,5 @@ class TensorWrapper(gym.Wrapper): obs, reward, done, info = self.env.step(action.numpy()) info = defaultdict(float, info) info['success'] = float(info['success']) + info['terminated'] = torch.tensor(float(info['terminated'])) return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 9d49cf8..0f09b6c 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -22,6 +22,7 @@ class TDMPC2: {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._dynamics.parameters()}, {'params': self.model._reward.parameters()}, + {'params': self.model._terminated.parameters()}, {'params': self.model._Qs.parameters()}, {'params': self.model._task_emb.parameters() if self.cfg.multitask else []} ], lr=self.cfg.lr) @@ -95,12 +96,14 @@ class TDMPC2: def _estimate_value(self, z, actions, task): """Estimate value of a trajectory starting at latent state z and executing given actions.""" G, discount = 0, 1 + terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device) for t in range(self.cfg.horizon): reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) z = self.model.next(z, actions[t], task) - G += discount * reward + G += discount * (1-terminated) * reward discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount - return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') + terminated = torch.clip_(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.) + return G + discount * (1-terminated) * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') @torch.no_grad() def plan(self, z, t0=False, eval_mode=False, task=None): @@ -199,13 +202,14 @@ class TDMPC2: return pi_loss.item() @torch.no_grad() - def _td_target(self, next_z, reward, task): + def _td_target(self, next_z, reward, terminated, task): """ Compute the TD-target from a reward and the observation at the following time step. Args: next_z (torch.Tensor): Latent state at the following time step. reward (torch.Tensor): Reward at the current time step. + terminated (torch.Tensor): Termination signal at the current time step. task (torch.Tensor): Task index (only used for multi-task experiments). Returns: @@ -213,7 +217,7 @@ class TDMPC2: """ pi = self.model.pi(next_z, task)[1] discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount - return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) + return reward + discount * (1-terminated) * self.model.Q(next_z, pi, task, return_type='min', target=True) def update(self, buffer): """ @@ -225,12 +229,12 @@ class TDMPC2: Returns: dict: Dictionary of training statistics. """ - obs, action, reward, task = buffer.sample() + obs, action, reward, terminated, task = buffer.sample() # Compute targets with torch.no_grad(): next_z = self.model.encode(obs[1:], task) - td_targets = self._td_target(next_z, reward, task) + td_targets = self._td_target(next_z, reward, terminated, task) # Prepare for update self.optim.zero_grad(set_to_none=True) @@ -250,19 +254,23 @@ class TDMPC2: _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) + terminated_preds = self.model.terminated(_zs, task) # Compute losses - reward_loss, value_loss = 0, 0 + reward_loss, terminated_loss, value_loss = 0, 0, 0 for t in range(self.cfg.horizon): reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t + terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t for q in range(self.cfg.num_q): value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t consistency_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon) + terminated_loss *= (1/self.cfg.horizon) value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) total_loss = ( self.cfg.consistency_coef * consistency_loss + self.cfg.reward_coef * reward_loss + + self.cfg.terminated_coef * terminated_loss + self.cfg.value_coef * value_loss ) @@ -282,6 +290,7 @@ class TDMPC2: return { "consistency_loss": float(consistency_loss.mean().item()), "reward_loss": float(reward_loss.mean().item()), + "terminated_loss": float(terminated_loss.mean().item()), "value_loss": float(value_loss.mean().item()), "pi_loss": pi_loss, "total_loss": float(total_loss.mean().item()), diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index ca33009..e365404 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -47,8 +47,10 @@ class OnlineTrainer(Trainer): episode_success=np.nanmean(ep_successes), ) - def to_td(self, obs, action=None, reward=None): + def to_td(self, obs=None, action=None, reward=None, terminated=None): """Creates a TensorDict for a new episode.""" + if obs is None: + obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan')) if isinstance(obs, dict): obs = TensorDict(obs, batch_size=(), device='cpu') else: @@ -57,12 +59,15 @@ class OnlineTrainer(Trainer): action = torch.full_like(self.env.rand_act(), float('nan')) if reward is None: reward = torch.tensor(float('nan')) + if terminated is None: + terminated = torch.tensor(float('nan')) td = TensorDict(dict( obs=obs, action=action.unsqueeze(0), reward=reward.unsqueeze(0), + terminated=terminated.unsqueeze(0), ), batch_size=(1,)) - return td + return td def train(self): """Train a TD-MPC2 agent.""" @@ -88,6 +93,7 @@ class OnlineTrainer(Trainer): ) train_metrics.update(self.common_metrics()) self.logger.log(train_metrics, 'train') + self._tds.append(self.to_td()) # Separate episodes with NaNs self._ep_idx = self.buffer.add(torch.cat(self._tds)) obs = self.env.reset() @@ -99,7 +105,7 @@ class OnlineTrainer(Trainer): else: action = self.env.rand_act() obs, reward, done, info = self.env.step(action) - self._tds.append(self.to_td(obs, action, reward)) + self._tds.append(self.to_td(obs, action, reward, info['terminated'])) # Update agent if self._step >= self.cfg.seed_steps: From fabf01a5ec92e7cc5793ee9cbbced881a9c13798 Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Sun, 7 Jan 2024 19:28:41 -0800 Subject: [PATCH 2/2] solves episodic variant of cartpole-balance-sparse --- tdmpc2/common/buffer.py | 2 +- tdmpc2/config.yaml | 3 ++- tdmpc2/envs/dmcontrol.py | 4 ++++ tdmpc2/envs/maniskill.py | 7 ++----- tdmpc2/envs/wrappers/episodic.py | 24 ++++++++++++++++++++++++ tdmpc2/tdmpc2.py | 10 +++++----- tdmpc2/train.py | 3 +++ 7 files changed, 41 insertions(+), 12 deletions(-) create mode 100644 tdmpc2/envs/wrappers/episodic.py diff --git a/tdmpc2/common/buffer.py b/tdmpc2/common/buffer.py index a060a3d..139348b 100644 --- a/tdmpc2/common/buffer.py +++ b/tdmpc2/common/buffer.py @@ -78,7 +78,7 @@ class Buffer(): obs = td['obs'] action = td['action'][1:] reward = td['reward'][1:].unsqueeze(-1) - terminated = td['terminated'][1:].unsqueeze(-1) + terminated = td['terminated'][-1].unsqueeze(-1) task = td['task'][0] if 'task' in td.keys() else None return self._to_device(obs, action, reward, terminated, task) diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index cd3ac11..2288dfb 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -2,8 +2,9 @@ defaults: - override hydra/launcher: submitit_local # environment -task: dog-run +task: cartpole-balance-sparse obs: state +episodic: true # evaluation checkpoint: ??? diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..bda699c 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -10,6 +10,8 @@ from dm_control.suite.wrappers import action_scale from dm_env import StepType, specs import gym +from envs.wrappers.episodic import EpisodicWrapper + class ExtendedTimeStep(NamedTuple): step_type: Any @@ -197,4 +199,6 @@ def make_env(cfg): env = action_scale.Wrapper(env, minimum=-1., maximum=1.) env = ExtendedTimeStepWrapper(env) env = TimeStepToGymWrapper(env, domain, task) + if cfg.episodic: + env = EpisodicWrapper(cfg, env) return env diff --git a/tdmpc2/envs/maniskill.py b/tdmpc2/envs/maniskill.py index 4a37b6d..7b0b6ed 100644 --- a/tdmpc2/envs/maniskill.py +++ b/tdmpc2/envs/maniskill.py @@ -47,12 +47,9 @@ class ManiSkillWrapper(gym.Wrapper): def step(self, action): reward = 0 for _ in range(2): - obs, r, done, info = self.env.step(action) + obs, r, _, info = self.env.step(action) reward += r - info['terminated'] = done - if done: - break - return obs, reward, done, info + return obs, reward, False, info @property def unwrapped(self): diff --git a/tdmpc2/envs/wrappers/episodic.py b/tdmpc2/envs/wrappers/episodic.py new file mode 100644 index 0000000..f27c8fc --- /dev/null +++ b/tdmpc2/envs/wrappers/episodic.py @@ -0,0 +1,24 @@ +from collections import deque + +import gym +import numpy as np +import torch + + +class EpisodicWrapper(gym.Wrapper): + """ + Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment. + """ + + def __init__(self, cfg, env): + super().__init__(env) + assert cfg.task == 'cartpole-balance-sparse' + self.cfg = cfg + self.env = env + + def step(self, action): + obs, reward, done, info = self.env.step(action) + if self.cfg.episodic and reward == 0: + done = True + info['terminated'] = True + return obs, reward, done, info diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 0f09b6c..3abdd1f 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -230,7 +230,7 @@ class TDMPC2: dict: Dictionary of training statistics. """ obs, action, reward, terminated, task = buffer.sample() - + # Compute targets with torch.no_grad(): next_z = self.model.encode(obs[1:], task) @@ -254,15 +254,15 @@ class TDMPC2: _zs = zs[:-1] qs = self.model.Q(_zs, action, task, return_type='all') reward_preds = self.model.reward(_zs, action, task) - terminated_preds = self.model.terminated(_zs, task) - + terminated_pred = self.model.terminated(zs[-1], task) + # Compute losses - reward_loss, terminated_loss, value_loss = 0, 0, 0 + reward_loss, value_loss = 0, 0 for t in range(self.cfg.horizon): reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t - terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t for q in range(self.cfg.num_q): value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t + terminated_loss = F.binary_cross_entropy(terminated_pred, terminated) consistency_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon) terminated_loss *= (1/self.cfg.horizon) diff --git a/tdmpc2/train.py b/tdmpc2/train.py index 5953bb2..1680871 100755 --- a/tdmpc2/train.py +++ b/tdmpc2/train.py @@ -46,6 +46,9 @@ def train(cfg: dict): set_seed(cfg.seed) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) + assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \ + f'This branch is experimental and only supports cartpole-balance-sparse at this time.' + trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer = trainer_cls( cfg=cfg,