diff --git a/tdmpc2/common/math.py b/tdmpc2/common/math.py index 5ac92ad..11f3d28 100644 --- a/tdmpc2/common/math.py +++ b/tdmpc2/common/math.py @@ -42,6 +42,16 @@ def squash(mu, pi, log_pi): return mu, pi, log_pi +def int_to_one_hot(x, num_classes): + """ + Converts an integer tensor to a one-hot tensor. + Supports batched inputs. + """ + one_hot = torch.zeros(*x.shape, num_classes, device=x.device) + one_hot.scatter_(-1, x.unsqueeze(-1), 1) + return one_hot + + def symlog(x): """ Symmetric logarithmic function. diff --git a/tdmpc2/common/parser.py b/tdmpc2/common/parser.py index e162eac..e931cd5 100755 --- a/tdmpc2/common/parser.py +++ b/tdmpc2/common/parser.py @@ -77,6 +77,10 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf: cfg.task_dim = 0 cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) + # Check action space compatibility + if cfg.get('action', 'continuous') == 'discrete': + assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.' + # Check torch.compile compatibility if cfg.get('compile', False): assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' diff --git a/tdmpc2/common/world_model.py b/tdmpc2/common/world_model.py index eb9633d..3482438 100644 --- a/tdmpc2/common/world_model.py +++ b/tdmpc2/common/world_model.py @@ -2,9 +2,12 @@ from copy import deepcopy import torch import torch.nn as nn +import torch.nn.functional as F +from torch.distributions.categorical import Categorical +from tensordict.nn import TensorDictParams from common import layers, math, init -from tensordict.nn import TensorDictParams + class WorldModel(nn.Module): """ @@ -23,7 +26,7 @@ class WorldModel(nn.Module): self._encoder = layers.enc(cfg) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) - self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) + self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action == 'continuous' else cfg.action_dim) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self.apply(init.weight_init) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) @@ -121,15 +124,12 @@ class WorldModel(nn.Module): z = torch.cat([z, a], dim=-1) return self._reward(z) - def pi(self, z, task): + def _continuous_pi(self, z, task): """ Samples an action from the policy prior. The policy prior is a Gaussian distribution with mean and (log) std predicted by a neural network. """ - if self.cfg.multitask: - z = self.task_emb(z, task) - # Gaussian policy prior mu, log_std = self._pi(z).chunk(2, dim=-1) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) @@ -149,6 +149,41 @@ class WorldModel(nn.Module): return mu, pi, log_pi, log_std + def _discrete_pi(self, z, task): + """ + Samples an action from the policy prior. + The policy prior is a categorical distribution + with logits predicted by a neural network. + """ + # Categorical policy prior + logits = self._pi(z) + policy_dist = Categorical(logits=logits) + action = policy_dist.sample() + action = math.int_to_one_hot(action, self.cfg.action_dim) + + # Action probabilities for calculating the adapted soft-Q loss + action_probs = policy_dist.probs + log_prob = F.log_softmax(logits, dim=-1) + + return action, action, log_prob, action_probs + + + def pi(self, z, task): + """ + Samples an action from the policy prior. + Policy can be either continuous (Gaussian) or discrete (categorical). + """ + if self.cfg.multitask: + z = self.task_emb(z, task) + + if self.cfg.action == 'discrete': + return self._discrete_pi(z, task) + elif self.cfg.action == 'continuous': + return self._continuous_pi(z, task) + else: + raise NotImplementedError(f"Action space {self.cfg.action} not supported.") + + def Q(self, z, a, task, return_type='min', target=False, detach=False): """ Predict state-action value. diff --git a/tdmpc2/config.yaml b/tdmpc2/config.yaml index 597c829..6718a82 100755 --- a/tdmpc2/config.yaml +++ b/tdmpc2/config.yaml @@ -2,8 +2,9 @@ defaults: - override hydra/launcher: submitit_local # environment -task: dog-run +task: cartpole-swingup obs: state +action: discrete # evaluation checkpoint: ??? @@ -29,7 +30,7 @@ exp_name: default data_dir: ??? # planning -mpc: true +mpc: false iterations: 6 num_samples: 512 num_elites: 64 diff --git a/tdmpc2/envs/__init__.py b/tdmpc2/envs/__init__.py index 6326a9e..3b3f91c 100644 --- a/tdmpc2/envs/__init__.py +++ b/tdmpc2/envs/__init__.py @@ -3,6 +3,7 @@ import warnings import gym +from envs.wrappers.discrete import DiscreteWrapper from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.pixels import PixelWrapper from envs.wrappers.tensor import TensorWrapper @@ -65,6 +66,7 @@ def make_env(cfg): for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: try: env = fn(cfg) + break except ValueError: pass if env is None: @@ -72,11 +74,13 @@ def make_env(cfg): env = TensorWrapper(env) if cfg.get('obs', 'state') == 'rgb': env = PixelWrapper(cfg, env) + if cfg.get('action', 'discrete'): + env = DiscreteWrapper(env) try: # Dict cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} except: # Box cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} - cfg.action_dim = env.action_space.shape[0] + cfg.action_dim = env.action_space.n if cfg.action == 'discrete' else env.action_space.shape[0] cfg.episode_length = env.max_episode_steps cfg.seed_steps = max(1000, 5*cfg.episode_length) return env diff --git a/tdmpc2/envs/wrappers/discrete.py b/tdmpc2/envs/wrappers/discrete.py new file mode 100644 index 0000000..caf5e26 --- /dev/null +++ b/tdmpc2/envs/wrappers/discrete.py @@ -0,0 +1,35 @@ +import gym +import numpy as np +import torch + +from common import math + + +class DiscreteWrapper(gym.Wrapper): + """ + Wrapper for converting continuous action spaces to discrete via binning. + """ + + def __init__(self, env): + super().__init__(env) + self.continuous_dims = self.env.action_space.shape[0] + # Bins at [-1, 0, 1] for each dimension + # Discrete actions include all possible combinations of these bins + self.action_space = gym.spaces.Discrete(3 ** self.continuous_dims) + + def rand_act(self): + action = torch.tensor(self.action_space.sample(), dtype=torch.int64) + return math.int_to_one_hot(action, self.action_space.n) + + def _discrete_to_continuous(self, action): + # Convert a discrete action to a continuous action + # action is a one-hot encoded tensor + action = torch.argmax(action) + action = action.item() + action = [action // 3 ** i % 3 for i in range(self.continuous_dims)] + action = torch.tensor(action, dtype=torch.float32) + return (action - 1) / 1 + + def step(self, action): + action = self._discrete_to_continuous(action) + return self.env.step(action) diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index e4d8ec2..1deb61b 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -103,11 +103,11 @@ class TDMPC2(torch.nn.Module): if task is not None: task = torch.tensor([task], device=self.device) if self.cfg.mpc: - a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) + action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) else: z = self.model.encode(obs, task) - a = self.model.pi(z, task)[int(not eval_mode)][0] - return a.cpu() + action = self.model.pi(z, task)[int(not eval_mode)][0] + return action.cpu() @torch.no_grad() def _estimate_value(self, z, actions, task): @@ -202,14 +202,17 @@ class TDMPC2(torch.nn.Module): Returns: float: Loss of the policy update. """ - _, pis, log_pis, _ = self.model.pi(zs, task) - qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) + _, actions, log_probs, action_probs = self.model.pi(zs, task) + qs = self.model.Q(zs, actions, task, return_type='avg', detach=True) self.scale.update(qs[0]) qs = self.scale(qs) # Loss is a weighted sum of Q-values rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) - pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() + if self.cfg.action == 'discrete': + pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean() + else: + pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean() pi_loss.backward() pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) self.pi_optim.step() diff --git a/tdmpc2/trainer/online_trainer.py b/tdmpc2/trainer/online_trainer.py index 0d2f062..7c8f3c5 100755 --- a/tdmpc2/trainer/online_trainer.py +++ b/tdmpc2/trainer/online_trainer.py @@ -54,7 +54,8 @@ class OnlineTrainer(Trainer): else: obs = obs.unsqueeze(0).cpu() if action is None: - action = torch.full_like(self.env.rand_act(), float('nan')) + action_val = -1 if self.cfg.action == 'discrete' else float('nan') + action = torch.full_like(self.env.rand_act(), action_val) if reward is None: reward = torch.tensor(float('nan')) td = TensorDict(