init
This commit is contained in:
@@ -42,6 +42,16 @@ def squash(mu, pi, log_pi):
|
|||||||
return mu, pi, log_pi
|
return mu, pi, log_pi
|
||||||
|
|
||||||
|
|
||||||
|
def int_to_one_hot(x, num_classes):
|
||||||
|
"""
|
||||||
|
Converts an integer tensor to a one-hot tensor.
|
||||||
|
Supports batched inputs.
|
||||||
|
"""
|
||||||
|
one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
|
||||||
|
one_hot.scatter_(-1, x.unsqueeze(-1), 1)
|
||||||
|
return one_hot
|
||||||
|
|
||||||
|
|
||||||
def symlog(x):
|
def symlog(x):
|
||||||
"""
|
"""
|
||||||
Symmetric logarithmic function.
|
Symmetric logarithmic function.
|
||||||
|
|||||||
@@ -77,6 +77,10 @@ def parse_cfg(cfg: OmegaConf) -> OmegaConf:
|
|||||||
cfg.task_dim = 0
|
cfg.task_dim = 0
|
||||||
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
cfg.tasks = TASK_SET.get(cfg.task, [cfg.task])
|
||||||
|
|
||||||
|
# Check action space compatibility
|
||||||
|
if cfg.get('action', 'continuous') == 'discrete':
|
||||||
|
assert not cfg.multitask, 'Discrete actions are not supported in multi-task settings.'
|
||||||
|
|
||||||
# Check torch.compile compatibility
|
# Check torch.compile compatibility
|
||||||
if cfg.get('compile', False):
|
if cfg.get('compile', False):
|
||||||
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
|
assert cfg.obs == 'state', 'torch.compile only supports state observations at the moment.'
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ from copy import deepcopy
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.distributions.categorical import Categorical
|
||||||
|
from tensordict.nn import TensorDictParams
|
||||||
|
|
||||||
from common import layers, math, init
|
from common import layers, math, init
|
||||||
from tensordict.nn import TensorDictParams
|
|
||||||
|
|
||||||
class WorldModel(nn.Module):
|
class WorldModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -23,7 +26,7 @@ class WorldModel(nn.Module):
|
|||||||
self._encoder = layers.enc(cfg)
|
self._encoder = layers.enc(cfg)
|
||||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action == 'continuous' else cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
|
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
|
||||||
@@ -121,15 +124,12 @@ class WorldModel(nn.Module):
|
|||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._reward(z)
|
return self._reward(z)
|
||||||
|
|
||||||
def pi(self, z, task):
|
def _continuous_pi(self, z, task):
|
||||||
"""
|
"""
|
||||||
Samples an action from the policy prior.
|
Samples an action from the policy prior.
|
||||||
The policy prior is a Gaussian distribution with
|
The policy prior is a Gaussian distribution with
|
||||||
mean and (log) std predicted by a neural network.
|
mean and (log) std predicted by a neural network.
|
||||||
"""
|
"""
|
||||||
if self.cfg.multitask:
|
|
||||||
z = self.task_emb(z, task)
|
|
||||||
|
|
||||||
# Gaussian policy prior
|
# Gaussian policy prior
|
||||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
||||||
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
||||||
@@ -149,6 +149,41 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
return mu, pi, log_pi, log_std
|
return mu, pi, log_pi, log_std
|
||||||
|
|
||||||
|
def _discrete_pi(self, z, task):
|
||||||
|
"""
|
||||||
|
Samples an action from the policy prior.
|
||||||
|
The policy prior is a categorical distribution
|
||||||
|
with logits predicted by a neural network.
|
||||||
|
"""
|
||||||
|
# Categorical policy prior
|
||||||
|
logits = self._pi(z)
|
||||||
|
policy_dist = Categorical(logits=logits)
|
||||||
|
action = policy_dist.sample()
|
||||||
|
action = math.int_to_one_hot(action, self.cfg.action_dim)
|
||||||
|
|
||||||
|
# Action probabilities for calculating the adapted soft-Q loss
|
||||||
|
action_probs = policy_dist.probs
|
||||||
|
log_prob = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
return action, action, log_prob, action_probs
|
||||||
|
|
||||||
|
|
||||||
|
def pi(self, z, task):
|
||||||
|
"""
|
||||||
|
Samples an action from the policy prior.
|
||||||
|
Policy can be either continuous (Gaussian) or discrete (categorical).
|
||||||
|
"""
|
||||||
|
if self.cfg.multitask:
|
||||||
|
z = self.task_emb(z, task)
|
||||||
|
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
return self._discrete_pi(z, task)
|
||||||
|
elif self.cfg.action == 'continuous':
|
||||||
|
return self._continuous_pi(z, task)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
|
||||||
|
|
||||||
|
|
||||||
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
||||||
"""
|
"""
|
||||||
Predict state-action value.
|
Predict state-action value.
|
||||||
|
|||||||
@@ -2,8 +2,9 @@ defaults:
|
|||||||
- override hydra/launcher: submitit_local
|
- override hydra/launcher: submitit_local
|
||||||
|
|
||||||
# environment
|
# environment
|
||||||
task: dog-run
|
task: cartpole-swingup
|
||||||
obs: state
|
obs: state
|
||||||
|
action: discrete
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
@@ -29,7 +30,7 @@ exp_name: default
|
|||||||
data_dir: ???
|
data_dir: ???
|
||||||
|
|
||||||
# planning
|
# planning
|
||||||
mpc: true
|
mpc: false
|
||||||
iterations: 6
|
iterations: 6
|
||||||
num_samples: 512
|
num_samples: 512
|
||||||
num_elites: 64
|
num_elites: 64
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import warnings
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
|
from envs.wrappers.discrete import DiscreteWrapper
|
||||||
from envs.wrappers.multitask import MultitaskWrapper
|
from envs.wrappers.multitask import MultitaskWrapper
|
||||||
from envs.wrappers.pixels import PixelWrapper
|
from envs.wrappers.pixels import PixelWrapper
|
||||||
from envs.wrappers.tensor import TensorWrapper
|
from envs.wrappers.tensor import TensorWrapper
|
||||||
@@ -65,6 +66,7 @@ def make_env(cfg):
|
|||||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||||
try:
|
try:
|
||||||
env = fn(cfg)
|
env = fn(cfg)
|
||||||
|
break
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
if env is None:
|
if env is None:
|
||||||
@@ -72,11 +74,13 @@ def make_env(cfg):
|
|||||||
env = TensorWrapper(env)
|
env = TensorWrapper(env)
|
||||||
if cfg.get('obs', 'state') == 'rgb':
|
if cfg.get('obs', 'state') == 'rgb':
|
||||||
env = PixelWrapper(cfg, env)
|
env = PixelWrapper(cfg, env)
|
||||||
|
if cfg.get('action', 'discrete'):
|
||||||
|
env = DiscreteWrapper(env)
|
||||||
try: # Dict
|
try: # Dict
|
||||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||||
except: # Box
|
except: # Box
|
||||||
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
||||||
cfg.action_dim = env.action_space.shape[0]
|
cfg.action_dim = env.action_space.n if cfg.action == 'discrete' else env.action_space.shape[0]
|
||||||
cfg.episode_length = env.max_episode_steps
|
cfg.episode_length = env.max_episode_steps
|
||||||
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
||||||
return env
|
return env
|
||||||
|
|||||||
35
tdmpc2/envs/wrappers/discrete.py
Normal file
35
tdmpc2/envs/wrappers/discrete.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from common import math
|
||||||
|
|
||||||
|
|
||||||
|
class DiscreteWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for converting continuous action spaces to discrete via binning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, env):
|
||||||
|
super().__init__(env)
|
||||||
|
self.continuous_dims = self.env.action_space.shape[0]
|
||||||
|
# Bins at [-1, 0, 1] for each dimension
|
||||||
|
# Discrete actions include all possible combinations of these bins
|
||||||
|
self.action_space = gym.spaces.Discrete(3 ** self.continuous_dims)
|
||||||
|
|
||||||
|
def rand_act(self):
|
||||||
|
action = torch.tensor(self.action_space.sample(), dtype=torch.int64)
|
||||||
|
return math.int_to_one_hot(action, self.action_space.n)
|
||||||
|
|
||||||
|
def _discrete_to_continuous(self, action):
|
||||||
|
# Convert a discrete action to a continuous action
|
||||||
|
# action is a one-hot encoded tensor
|
||||||
|
action = torch.argmax(action)
|
||||||
|
action = action.item()
|
||||||
|
action = [action // 3 ** i % 3 for i in range(self.continuous_dims)]
|
||||||
|
action = torch.tensor(action, dtype=torch.float32)
|
||||||
|
return (action - 1) / 1
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
action = self._discrete_to_continuous(action)
|
||||||
|
return self.env.step(action)
|
||||||
@@ -103,11 +103,11 @@ class TDMPC2(torch.nn.Module):
|
|||||||
if task is not None:
|
if task is not None:
|
||||||
task = torch.tensor([task], device=self.device)
|
task = torch.tensor([task], device=self.device)
|
||||||
if self.cfg.mpc:
|
if self.cfg.mpc:
|
||||||
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
||||||
else:
|
else:
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
action = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||||
return a.cpu()
|
return action.cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
@@ -202,14 +202,17 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
float: Loss of the policy update.
|
float: Loss of the policy update.
|
||||||
"""
|
"""
|
||||||
_, pis, log_pis, _ = self.model.pi(zs, task)
|
_, actions, log_probs, action_probs = self.model.pi(zs, task)
|
||||||
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
|
qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
|
||||||
self.scale.update(qs[0])
|
self.scale.update(qs[0])
|
||||||
qs = self.scale(qs)
|
qs = self.scale(qs)
|
||||||
|
|
||||||
# Loss is a weighted sum of Q-values
|
# Loss is a weighted sum of Q-values
|
||||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||||
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
if self.cfg.action == 'discrete':
|
||||||
|
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
|
||||||
|
else:
|
||||||
|
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
|
||||||
pi_loss.backward()
|
pi_loss.backward()
|
||||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||||
self.pi_optim.step()
|
self.pi_optim.step()
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ class OnlineTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
obs = obs.unsqueeze(0).cpu()
|
obs = obs.unsqueeze(0).cpu()
|
||||||
if action is None:
|
if action is None:
|
||||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
action_val = -1 if self.cfg.action == 'discrete' else float('nan')
|
||||||
|
action = torch.full_like(self.env.rand_act(), action_val)
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
td = TensorDict(
|
td = TensorDict(
|
||||||
|
|||||||
Reference in New Issue
Block a user