maximum entropy discrete policy

This commit is contained in:
Nicklas Hansen
2024-11-22 22:51:47 -08:00
parent d463268bd2
commit 4dcd933b8f
8 changed files with 51 additions and 63 deletions

View File

@@ -77,12 +77,6 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0 cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Check action space compatibility
assert cfg.action in ['continuous', 'discrete'], \
f'Invalid action space {cfg.action}. Must be one of ["continuous", "discrete"]'
if cfg.action == 'discrete':
assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.'
# Check torch.compile compatibility # Check torch.compile compatibility
if cfg.get('compile', False): if cfg.get('compile', False):
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'

View File

@@ -26,7 +26,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action == 'continuous' else cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action_space == 'continuous' else cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
@@ -155,37 +155,19 @@ class WorldModel(nn.Module):
The policy prior is a categorical distribution The policy prior is a categorical distribution
with logits predicted by a neural network. with logits predicted by a neural network.
""" """
assert task is None, "Discrete policy does not support multitask."
# Categorical policy prior # Categorical policy prior
# logits = self._pi(z) logits = self._pi(z)
# policy_dist = Categorical(logits=logits) policy_dist = Categorical(logits=logits)
# action = policy_dist.sample()
# action = math.int_to_one_hot(action, self.cfg.action_dim)
# # Action probabilities for calculating the adapted soft-Q loss action = policy_dist.sample()
# action_probs = policy_dist.probs action_probs = policy_dist.probs
# log_prob = F.log_softmax(logits, dim=-1) log_prob = F.log_softmax(logits, dim=-1)
# return action, action, log_prob, action_probs one_hot_action = math.int_to_one_hot(action, self.cfg.action_dim)
# Argmax policy
# enumerate all possible one-hot actions
# and return the one with the highest Q-value
# for the given state.
actions = torch.eye(self.cfg.action_dim, device=z.device).unsqueeze(0)
if z.dim() == 2:
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
actions = actions.repeat(z.shape[0], 1, 1)
elif z.dim() == 3:
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
actions = actions.unsqueeze(0).repeat(z.shape[0], z.shape[1], 1, 1)
Q = self.Q(z, actions, task, return_type='min')
action = Q.argmax(dim=-2)
action = math.int_to_one_hot(action, self.cfg.action_dim)
return action, action, None, None
return action, one_hot_action, log_prob, action_probs
def pi(self, z, task): def pi(self, z, task):
""" """
@@ -195,14 +177,13 @@ class WorldModel(nn.Module):
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
return self._discrete_pi(z, task) return self._discrete_pi(z, task)
elif self.cfg.action == 'continuous': elif self.cfg.action_space == 'continuous':
return self._continuous_pi(z, task) return self._continuous_pi(z, task)
else: else:
raise NotImplementedError(f"Action space {self.cfg.action} not supported.") raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
def Q(self, z, a, task, return_type='min', target=False, detach=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """
Predict state-action value. Predict state-action value.

View File

@@ -2,9 +2,8 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: cartpole-swingup task: discrete-cartpole-swingup
obs: state obs: state
action: discrete
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -80,6 +79,7 @@ task_title: ???
multitask: ??? multitask: ???
tasks: ??? tasks: ???
obs_shape: ??? obs_shape: ???
action_space: ???
action_dim: ??? action_dim: ???
episode_length: ??? episode_length: ???
obs_shapes: ??? obs_shapes: ???

View File

@@ -60,9 +60,13 @@ def make_env(cfg):
gym.logger.set_level(40) gym.logger.set_level(40)
if cfg.multitask: if cfg.multitask:
env = make_multitask_env(cfg) env = make_multitask_env(cfg)
else: else:
env = None env = None
if cfg.task.startswith('discrete-'):
discrete = True
cfg.task = cfg.task.replace('discrete-', '')
else:
discrete = False
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
@@ -72,15 +76,18 @@ def make_env(cfg):
if env is None: if env is None:
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.') raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
env = TensorWrapper(env) env = TensorWrapper(env)
if discrete:
env = DiscreteWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)
if cfg.get('action', 'continuous') == 'discrete':
env = DiscreteWrapper(env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box except: # Box
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.n if cfg.action == 'discrete' else env.action_space.shape[0] assert not isinstance(env.action_space, (gym.spaces.Dict, gym.spaces.MultiDiscrete)), \
'Dict and MultiDiscrete action spaces are not supported.'
cfg.action_space = 'discrete' if isinstance(env.action_space, gym.spaces.Discrete) else 'continuous'
cfg.action_dim = env.action_space.n if cfg.action_space == 'discrete' else env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length)
return env return env

View File

@@ -1,4 +1,4 @@
from collections import deque, defaultdict from collections import defaultdict
from typing import Any, NamedTuple from typing import Any, NamedTuple
import dm_env import dm_env
import numpy as np import numpy as np

View File

@@ -13,8 +13,7 @@ class DiscreteWrapper(gym.Wrapper):
super().__init__(env) super().__init__(env)
self.bins_per_dim = bins_per_dim self.bins_per_dim = bins_per_dim
self.continuous_dims = self.env.action_space.shape[0] self.continuous_dims = self.env.action_space.shape[0]
# Bins at [-1, 0, 1] for each dimension # Equally spaced bins along each dimension
# Discrete actions include all possible combinations of these bins
self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims) self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims)
def rand_act(self): def rand_act(self):
@@ -23,7 +22,6 @@ class DiscreteWrapper(gym.Wrapper):
def _discrete_to_continuous(self, action): def _discrete_to_continuous(self, action):
# Convert a discrete action to a continuous action # Convert a discrete action to a continuous action
# action is a one-hot encoded tensor
action = torch.argmax(action) action = torch.argmax(action)
action = action.item() action = action.item()
action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)] action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)]

View File

@@ -107,7 +107,7 @@ class TDMPC2(torch.nn.Module):
else: else:
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
action = self.model.pi(z, task)[int(not eval_mode)][0] action = self.model.pi(z, task)[int(not eval_mode)][0]
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
action = action.squeeze(0) # TODO: this is a bit hacky action = action.squeeze(0) # TODO: this is a bit hacky
return action.cpu() return action.cpu()
@@ -122,7 +122,7 @@ class TDMPC2(torch.nn.Module):
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update discount = discount * discount_update
pi = self.model.pi(z, task)[1] pi = self.model.pi(z, task)[1]
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
pi = pi.squeeze(1) # TODO: this is a bit hacky pi = pi.squeeze(1) # TODO: this is a bit hacky
return G + discount * self.model.Q(z, pi, task, return_type='avg') return G + discount * self.model.Q(z, pi, task, return_type='avg')
@@ -147,18 +147,18 @@ class TDMPC2(torch.nn.Module):
_z = z.repeat(self.cfg.num_pi_trajs, 1) _z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1): for t in range(self.cfg.horizon-1):
action = self.model.pi(_z, task)[1] action = self.model.pi(_z, task)[1]
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
action = action.squeeze(1) action = action.squeeze(1)
pi_actions[t] = action pi_actions[t] = action
_z = self.model.next(_z, pi_actions[t], task) _z = self.model.next(_z, pi_actions[t], task)
action = self.model.pi(_z, task)[1] action = self.model.pi(_z, task)[1]
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
action = action.squeeze(1) action = action.squeeze(1)
pi_actions[-1] = action pi_actions[-1] = action
# Initialize state and parameters # Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1) z = z.repeat(self.cfg.num_samples, 1)
if self.cfg.action == 'continuous': if self.cfg.action_space == 'continuous':
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device) mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device) std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0: if not t0:
@@ -168,7 +168,7 @@ class TDMPC2(torch.nn.Module):
actions[:, :self.cfg.num_pi_trajs] = pi_actions actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Random shooting # Random shooting
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
# Sample actions # Sample actions
actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device) actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device)
actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim) actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim)
@@ -235,13 +235,24 @@ class TDMPC2(torch.nn.Module):
float: Loss of the policy update. float: Loss of the policy update.
""" """
_, actions, log_probs, action_probs = self.model.pi(zs, task) _, actions, log_probs, action_probs = self.model.pi(zs, task)
if self.cfg.action_space == 'discrete':
actions = torch.eye(self.cfg.action_dim, device=zs.device).unsqueeze(0)
zs = zs.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
actions = actions.unsqueeze(0).repeat(zs.shape[0], zs.shape[1], 1, 1)
qs = self.model.Q(zs, actions, task, return_type='avg', detach=True) qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
self.scale.update(qs[0])
if self.cfg.action_space == 'discrete':
qs = qs.squeeze(-1)
self.scale.update(torch.sum(action_probs*qs,dim=(1,2),keepdim=True)[0])
else:
self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
# Loss is a weighted sum of Q-values # Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean() pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
else: else:
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean() pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
@@ -266,7 +277,7 @@ class TDMPC2(torch.nn.Module):
torch.Tensor: TD-target. torch.Tensor: TD-target.
""" """
pi = self.model.pi(next_z, task)[1] pi = self.model.pi(next_z, task)[1]
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
pi = pi.squeeze(2) # TODO: this is a bit hacky pi = pi.squeeze(2) # TODO: this is a bit hacky
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
@@ -318,10 +329,7 @@ class TDMPC2(torch.nn.Module):
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
# Update policy # Update policy
if self.cfg.action == 'continuous': pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
else:
pi_loss, pi_grad_norm = 0., 0.
# Update target Q-functions # Update target Q-functions
self.model.soft_update_target_Q() self.model.soft_update_target_Q()

View File

@@ -54,7 +54,7 @@ class OnlineTrainer(Trainer):
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action_val = -1 if self.cfg.action == 'discrete' else float('nan') action_val = -1 if self.cfg.action_space == 'discrete' else float('nan')
action = torch.full_like(self.env.rand_act(), action_val) action = torch.full_like(self.env.rand_act(), action_val)
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
@@ -98,7 +98,7 @@ class OnlineTrainer(Trainer):
action = self.agent.act(obs, t0=len(self._tds)==1) action = self.agent.act(obs, t0=len(self._tds)==1)
else: else:
action = self.env.rand_act() action = self.env.rand_act()
if self.cfg.action == 'discrete': if self.cfg.action_space == 'discrete':
# exploration schedule # exploration schedule
# minimum 0.01, maximum 0.05, anneal over 20k steps # minimum 0.01, maximum 0.05, anneal over 20k steps
if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000): if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):