diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index 90196a3..e162eac 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -77,12 +77,6 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) - # Check action space compatibility - assert cfg.action in ['continuous', 'discrete'], \ - f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]' - if cfg.action == 'discrete': - assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.' - # Check torch.compile compatibility if cfg.get('compile', False): assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index cfc3ed7..034f713 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -26,7 +26,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action == 'continuous' else cfg.action_dim) + self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action_space == 'continuous' else cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) @@ -155,37 +155,19 @@ class WorldModel(nn.Module): The policy prior is a categorical distribution with logits predicted by a neural network. """ + assert task is None, "Discrete policy does not support multitask." + # Categorical policy prior - # logits = self._pi(z) - # policy_dist = Categorical(logits=logits) - # action = policy_dist.sample() - # action = math.int_to_one_hot(action, self.cfg.action_dim) + logits = self._pi(z) + policy_dist = Categorical(logits=logits) + + action = policy_dist.sample() + action_probs = policy_dist.probs + log_prob = F.log_softmax(logits, dim=-1) - # # Action probabilities for calculating the adapted soft-Q loss - # action_probs = policy_dist.probs - # log_prob = F.log_softmax(logits, dim=-1) - - # return action, action, log_prob, action_probs - - # Argmax policy - # enumerate all possible one-hot actions - # and return the one with the highest Q-value - # for the given state. - actions = torch.eye(self.cfg.action_dim, device=z.device).unsqueeze(0) - if z.dim() == 2: - # z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim) - z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1) - actions = actions.repeat(z.shape[0], 1, 1) - elif z.dim() == 3: - # z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim) - z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1) - actions = actions.unsqueeze(0).repeat(z.shape[0], z.shape[1], 1, 1) - Q = self.Q(z, actions, task, return_type='min') - action = Q.argmax(dim=-2) - action = math.int_to_one_hot(action, self.cfg.action_dim) - - return action, action, None, None + one_hot_action = math.int_to_one_hot(action, self.cfg.action_dim) + return action, one_hot_action, log_prob, action_probs def pi(self, z, task): """ @@ -195,14 +177,13 @@ class WorldModel(nn.Module): if self.cfg.multitask: z = self.task_emb(z, task) - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': return self._discrete_pi(z, task) - elif self.cfg.action == 'continuous': + elif self.cfg.action_space == 'continuous': return self._continuous_pi(z, task) else: raise NotImplementedError(f"Action space {self.cfg.action} not supported.") - def Q(self, z, a, task, return_type='min', target=False, detach=False): """ Predict state-action value. diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index ab0f41b..6479095 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -2,9 +2,8 @@ defaults: - override hydra/launcher: submitit_local # environment -task: cartpole-swingup +task: discrete-cartpole-swingup obs: state -action: discrete # evaluation checkpoint: ??? @@ -80,6 +79,7 @@ task_title: ??? multitask: ??? tasks: ??? obs_shape: ??? +action_space: ??? action_dim: ??? episode_length: ??? obs_shapes: ??? diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 7da75d3..09b7d4e 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -60,9 +60,13 @@ def make_env(cfg): gym.logger.set_level(40) if cfg.multitask: env = make_multitask_env(cfg) - else: env = None + if cfg.task.startswith('discrete-'): + discrete = True + cfg.task = cfg.task.replace('discrete-', '') + else: + discrete = False for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) @@ -72,15 +76,18 @@ def make_env(cfg): if env is None: raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') env = TensorWrapper(env) + if discrete: + env = DiscreteWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) - if cfg.get('action', 'continuous') == 'discrete': - env = DiscreteWrapper(env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} except: # Box cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} - cfg.action_dim = env.action_space.n if cfg.action == 'discrete' else env.action_space.shape[0] + assert not isinstance(env.action_space, (gym.spaces.Dict, gym.spaces.MultiDiscrete)), \ + 'Dict and MultiDiscrete action spaces are not supported.' + cfg.action_space = 'discrete' if isinstance(env.action_space, gym.spaces.Discrete) else 'continuous' + cfg.action_dim = env.action_space.n if cfg.action_space == 'discrete' else env.action_space.shape[0] cfg.episode_length = env.max_episode_steps cfg.seed_steps = max(1000, 5*cfg.episode_length) return env diff --git a/tdmpc2/envs/dmcontrol.py b/tdmpc2/envs/dmcontrol.py index 97be75a..c98ea32 100644 --- a/tdmpc2/envs/dmcontrol.py +++ b/tdmpc2/envs/dmcontrol.py @@ -1,4 +1,4 @@ -from collections import deque, defaultdict +from collections import defaultdict from typing import Any, NamedTuple import dm_env import numpy as np diff --git a/tdmpc2/envs/wrappers/discrete.py b/tdmpc2/envs/wrappers/discrete.py index da590b4..5253689 100644 --- a/tdmpc2/envs/wrappers/discrete.py +++ b/tdmpc2/envs/wrappers/discrete.py @@ -13,8 +13,7 @@ class DiscreteWrapper(gym.Wrapper): super().__init__(env) self.bins_per_dim = bins_per_dim self.continuous_dims = self.env.action_space.shape[0] - # Bins at [-1, 0, 1] for each dimension - # Discrete actions include all possible combinations of these bins + # Equally spaced bins along each dimension self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims) def rand_act(self): @@ -23,7 +22,6 @@ class DiscreteWrapper(gym.Wrapper): def _discrete_to_continuous(self, action): # Convert a discrete action to a continuous action - # action is a one-hot encoded tensor action = torch.argmax(action) action = action.item() action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)] diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 6e2e780..0e0b68c 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -107,7 +107,7 @@ class TDMPC2(torch.nn.Module): else: z = self.model.encode(obs, task) action = self.model.pi(z, task)[int(not eval_mode)][0] - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': action = action.squeeze(0) # TODO: this is a bit hacky return action.cpu() @@ -122,7 +122,7 @@ class TDMPC2(torch.nn.Module): discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount = discount * discount_update pi = self.model.pi(z, task)[1] - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': pi = pi.squeeze(1) # TODO: this is a bit hacky return G + discount * self.model.Q(z, pi, task, return_type='avg') @@ -147,18 +147,18 @@ class TDMPC2(torch.nn.Module): _z = z.repeat(self.cfg.num_pi_trajs, 1) for t in range(self.cfg.horizon-1): action = self.model.pi(_z, task)[1] - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': action = action.squeeze(1) pi_actions[t] = action _z = self.model.next(_z, pi_actions[t], task) action = self.model.pi(_z, task)[1] - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': action = action.squeeze(1) pi_actions[-1] = action # Initialize state and parameters z = z.repeat(self.cfg.num_samples, 1) - if self.cfg.action == 'continuous': + if self.cfg.action_space == 'continuous': mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) if not t0: @@ -168,7 +168,7 @@ class TDMPC2(torch.nn.Module): actions[:, :self.cfg.num_pi_trajs] = pi_actions # Random shooting - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': # Sample actions actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device) actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim) @@ -235,13 +235,24 @@ class TDMPC2(torch.nn.Module): float: Loss of the policy update. """ _, actions, log_probs, action_probs = self.model.pi(zs, task) + + if self.cfg.action_space == 'discrete': + actions = torch.eye(self.cfg.action_dim, device=zs.device).unsqueeze(0) + zs = zs.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1) + actions = actions.unsqueeze(0).repeat(zs.shape[0], zs.shape[1], 1, 1) + qs = self.model.Q(zs, actions, task, return_type='avg', detach=True) - self.scale.update(qs[0]) + + if self.cfg.action_space == 'discrete': + qs = qs.squeeze(-1) + self.scale.update(torch.sum(action_probs*qs,dim=(1,2),keepdim=True)[0]) + else: + self.scale.update(qs[0]) qs = self.scale(qs) # Loss is a weighted sum of Q-values rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean() else: pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean() @@ -266,7 +277,7 @@ class TDMPC2(torch.nn.Module): torch.Tensor: TD-target. """ pi = self.model.pi(next_z, task)[1] - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': pi = pi.squeeze(2) # TODO: this is a bit hacky discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) @@ -318,10 +329,7 @@ class TDMPC2(torch.nn.Module): self.optim.zero_grad(set_to_none=True) # Update policy - if self.cfg.action == 'continuous': - pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) - else: - pi_loss, pi_grad_norm = 0., 0. + pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task) # Update target Q-functions self.model.soft_update_target_Q() diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 097cd61..4ab5650 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -54,7 +54,7 @@ class OnlineTrainer(Trainer): else: obs = obs.unsqueeze(0).cpu() if action is None: - action_val = -1 if self.cfg.action == 'discrete' else float('nan') + action_val = -1 if self.cfg.action_space == 'discrete' else float('nan') action = torch.full_like(self.env.rand_act(), action_val) if reward is None: reward = torch.tensor(float('nan')) @@ -98,7 +98,7 @@ class OnlineTrainer(Trainer): action = self.agent.act(obs, t0=len(self._tds)==1) else: action = self.env.rand_act() - if self.cfg.action == 'discrete': + if self.cfg.action_space == 'discrete': # exploration schedule # minimum 0.01, maximum 0.05, anneal over 20k steps if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):