This commit is contained in:
Nicklas Hansen
2024-11-11 18:13:24 -08:00
parent 1bfbcb7794
commit dee034070e
8 changed files with 109 additions and 16 deletions

View File

@@ -42,6 +42,16 @@ def squash(mu, pi, log_pi):
return mu, pi, log_pi return mu, pi, log_pi
def int_to_one_hot(x, num_classes):
"""
Converts an integer tensor to a one-hot tensor.
Supports batched inputs.
"""
one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
one_hot.scatter_(-1, x.unsqueeze(-1), 1)
return one_hot
def symlog(x): def symlog(x):
""" """
Symmetric logarithmic function. Symmetric logarithmic function.

View File

@@ -77,6 +77,10 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
cfg.task_dim = 0 cfg.task_dim = 0
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task]) cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
# Check action space compatibility
if cfg.get('action', 'continuous') == 'discrete':
assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.'
# Check torch.compile compatibility # Check torch.compile compatibility
if cfg.get('compile', False): if cfg.get('compile', False):
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.' assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'

View File

@@ -2,9 +2,12 @@ from copy import deepcopy
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from tensordict.nn import TensorDictParams
from common import layers, math, init from common import layers, math, init
from tensordict.nn import TensorDictParams
class WorldModel(nn.Module): class WorldModel(nn.Module):
""" """
@@ -23,7 +26,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action == 'continuous' else cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]]) init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
@@ -121,15 +124,12 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def pi(self, z, task): def _continuous_pi(self, z, task):
""" """
Samples an action from the policy prior. Samples an action from the policy prior.
The policy prior is a Gaussian distribution with The policy prior is a Gaussian distribution with
mean and (log) std predicted by a neural network. mean and (log) std predicted by a neural network.
""" """
if self.cfg.multitask:
z = self.task_emb(z, task)
# Gaussian policy prior # Gaussian policy prior
mu, log_std = self._pi(z).chunk(2, dim=-1) mu, log_std = self._pi(z).chunk(2, dim=-1)
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif) log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
@@ -149,6 +149,41 @@ class WorldModel(nn.Module):
return mu, pi, log_pi, log_std return mu, pi, log_pi, log_std
def _discrete_pi(self, z, task):
"""
Samples an action from the policy prior.
The policy prior is a categorical distribution
with logits predicted by a neural network.
"""
# Categorical policy prior
logits = self._pi(z)
policy_dist = Categorical(logits=logits)
action = policy_dist.sample()
action = math.int_to_one_hot(action, self.cfg.action_dim)
# Action probabilities for calculating the adapted soft-Q loss
action_probs = policy_dist.probs
log_prob = F.log_softmax(logits, dim=-1)
return action, action, log_prob, action_probs
def pi(self, z, task):
"""
Samples an action from the policy prior.
Policy can be either continuous (Gaussian) or discrete (categorical).
"""
if self.cfg.multitask:
z = self.task_emb(z, task)
if self.cfg.action == 'discrete':
return self._discrete_pi(z, task)
elif self.cfg.action == 'continuous':
return self._continuous_pi(z, task)
else:
raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
def Q(self, z, a, task, return_type='min', target=False, detach=False): def Q(self, z, a, task, return_type='min', target=False, detach=False):
""" """
Predict state-action value. Predict state-action value.

View File

@@ -2,8 +2,9 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: dog-run task: cartpole-swingup
obs: state obs: state
action: discrete
# evaluation # evaluation
checkpoint: ??? checkpoint: ???
@@ -29,7 +30,7 @@ exp_name: default
data_dir: ??? data_dir: ???
# planning # planning
mpc: true mpc: false
iterations: 6 iterations: 6
num_samples: 512 num_samples: 512
num_elites: 64 num_elites: 64

View File

@@ -3,6 +3,7 @@ import warnings
import gym import gym
from envs.wrappers.discrete import DiscreteWrapper
from envs.wrappers.multitask import MultitaskWrapper from envs.wrappers.multitask import MultitaskWrapper
from envs.wrappers.pixels import PixelWrapper from envs.wrappers.pixels import PixelWrapper
from envs.wrappers.tensor import TensorWrapper from envs.wrappers.tensor import TensorWrapper
@@ -65,6 +66,7 @@ def make_env(cfg):
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]: for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
try: try:
env = fn(cfg) env = fn(cfg)
break
except ValueError: except ValueError:
pass pass
if env is None: if env is None:
@@ -72,11 +74,13 @@ def make_env(cfg):
env = TensorWrapper(env) env = TensorWrapper(env)
if cfg.get('obs', 'state') == 'rgb': if cfg.get('obs', 'state') == 'rgb':
env = PixelWrapper(cfg, env) env = PixelWrapper(cfg, env)
if cfg.get('action', 'discrete'):
env = DiscreteWrapper(env)
try: # Dict try: # Dict
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()} cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
except: # Box except: # Box
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape} cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
cfg.action_dim = env.action_space.shape[0] cfg.action_dim = env.action_space.n if cfg.action == 'discrete' else env.action_space.shape[0]
cfg.episode_length = env.max_episode_steps cfg.episode_length = env.max_episode_steps
cfg.seed_steps = max(1000, 5*cfg.episode_length) cfg.seed_steps = max(1000, 5*cfg.episode_length)
return env return env

View File

@@ -0,0 +1,35 @@
import gym
import numpy as np
import torch
from common import math
class DiscreteWrapper(gym.Wrapper):
"""
Wrapper for converting continuous action spaces to discrete via binning.
"""
def __init__(self, env):
super().__init__(env)
self.continuous_dims = self.env.action_space.shape[0]
# Bins at [-1, 0, 1] for each dimension
# Discrete actions include all possible combinations of these bins
self.action_space = gym.spaces.Discrete(3 ** self.continuous_dims)
def rand_act(self):
action = torch.tensor(self.action_space.sample(), dtype=torch.int64)
return math.int_to_one_hot(action, self.action_space.n)
def _discrete_to_continuous(self, action):
# Convert a discrete action to a continuous action
# action is a one-hot encoded tensor
action = torch.argmax(action)
action = action.item()
action = [action // 3 ** i % 3 for i in range(self.continuous_dims)]
action = torch.tensor(action, dtype=torch.float32)
return (action - 1) / 1
def step(self, action):
action = self._discrete_to_continuous(action)
return self.env.step(action)

View File

@@ -103,11 +103,11 @@ class TDMPC2(torch.nn.Module):
if task is not None: if task is not None:
task = torch.tensor([task], device=self.device) task = torch.tensor([task], device=self.device)
if self.cfg.mpc: if self.cfg.mpc:
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task) action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
else: else:
z = self.model.encode(obs, task) z = self.model.encode(obs, task)
a = self.model.pi(z, task)[int(not eval_mode)][0] action = self.model.pi(z, task)[int(not eval_mode)][0]
return a.cpu() return action.cpu()
@torch.no_grad() @torch.no_grad()
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
@@ -202,14 +202,17 @@ class TDMPC2(torch.nn.Module):
Returns: Returns:
float: Loss of the policy update. float: Loss of the policy update.
""" """
_, pis, log_pis, _ = self.model.pi(zs, task) _, actions, log_probs, action_probs = self.model.pi(zs, task)
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True) qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
self.scale.update(qs[0]) self.scale.update(qs[0])
qs = self.scale(qs) qs = self.scale(qs)
# Loss is a weighted sum of Q-values # Loss is a weighted sum of Q-values
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device)) rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean() if self.cfg.action == 'discrete':
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
else:
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
pi_loss.backward() pi_loss.backward()
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm) pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
self.pi_optim.step() self.pi_optim.step()

View File

@@ -54,7 +54,8 @@ class OnlineTrainer(Trainer):
else: else:
obs = obs.unsqueeze(0).cpu() obs = obs.unsqueeze(0).cpu()
if action is None: if action is None:
action = torch.full_like(self.env.rand_act(), float('nan')) action_val = -1 if self.cfg.action == 'discrete' else float('nan')
action = torch.full_like(self.env.rand_act(), action_val)
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
td = TensorDict( td = TensorDict(