Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e276e6aaa | ||
|
|
4f07f1ced4 | ||
|
|
4dcd933b8f | ||
|
|
d463268bd2 | ||
|
|
88ad0620ca | ||
|
|
8280b82d5c | ||
|
|
a9b5ad0ff8 | ||
|
|
dc6720d322 | ||
|
|
dee034070e |
@@ -12,6 +12,10 @@ Official implementation of
|
||||
|
||||
----
|
||||
|
||||
**Discrete branch:** this branch is under active development and contains experimental support for discrete action spaces. We expect a stable release to be available in a few months. Please use the `main` branch for the time being.
|
||||
|
||||
----
|
||||
|
||||
**Announcement: training just got ~4.5x faster!**
|
||||
|
||||
Expect **~4.5x** faster wall-time (depending on hardware and task) with the most recent release (Nov 10, 2024). A majority of the speedups in this branch are enabled with the additional flag `compile=true`. To run the code with `compile=true`, **you will need to install the latest `nightly` versions of PyTorch, TensorDict, and TorchRL**. See `docker/environment.yaml` for a tested configuration. `compile=true` is available in state-based online RL at the moment, and we expect to roll out support across all settings in the coming months. Thank you to [Vincent Moens](https://github.com/vmoens) who has been a key contributor to our torch.compile compatibility!
|
||||
|
||||
@@ -42,6 +42,16 @@ def squash(mu, pi, log_pi):
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
def int_to_one_hot(x, num_classes):
|
||||
"""
|
||||
Converts an integer tensor to a one-hot tensor.
|
||||
Supports batched inputs.
|
||||
"""
|
||||
one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
|
||||
one_hot.scatter_(-1, x.unsqueeze(-1), 1)
|
||||
return one_hot
|
||||
|
||||
|
||||
def symlog(x):
|
||||
"""
|
||||
Symmetric logarithmic function.
|
||||
|
||||
@@ -2,9 +2,12 @@ from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.categorical import Categorical
|
||||
from tensordict.nn import TensorDictParams
|
||||
|
||||
from common import layers, math, init
|
||||
from tensordict.nn import TensorDictParams
|
||||
|
||||
|
||||
class WorldModel(nn.Module):
|
||||
"""
|
||||
@@ -23,7 +26,7 @@ class WorldModel(nn.Module):
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim if cfg.action_space == 'continuous' else cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
init.zero_([self._reward[-1].weight, self._Qs.params["2", "weight"]])
|
||||
@@ -121,15 +124,12 @@ class WorldModel(nn.Module):
|
||||
z = torch.cat([z, a], dim=-1)
|
||||
return self._reward(z)
|
||||
|
||||
def pi(self, z, task):
|
||||
def _continuous_pi(self, z, task):
|
||||
"""
|
||||
Samples an action from the policy prior.
|
||||
The policy prior is a Gaussian distribution with
|
||||
mean and (log) std predicted by a neural network.
|
||||
"""
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
|
||||
# Gaussian policy prior
|
||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
||||
log_std = math.log_std(log_std, self.log_std_min, self.log_std_dif)
|
||||
@@ -149,6 +149,41 @@ class WorldModel(nn.Module):
|
||||
|
||||
return mu, pi, log_pi, log_std
|
||||
|
||||
def _discrete_pi(self, z, task):
|
||||
"""
|
||||
Samples an action from the policy prior.
|
||||
The policy prior is a categorical distribution
|
||||
with logits predicted by a neural network.
|
||||
"""
|
||||
assert task is None, "Discrete policy does not support multitask."
|
||||
|
||||
# Categorical policy prior
|
||||
logits = self._pi(z)
|
||||
policy_dist = Categorical(logits=logits)
|
||||
|
||||
action = policy_dist.sample()
|
||||
action_probs = policy_dist.probs
|
||||
log_prob = F.log_softmax(logits, dim=-1)
|
||||
|
||||
one_hot_action = math.int_to_one_hot(action, self.cfg.action_dim)
|
||||
|
||||
return action, one_hot_action, log_prob, action_probs
|
||||
|
||||
def pi(self, z, task):
|
||||
"""
|
||||
Samples an action from the policy prior.
|
||||
Policy can be either continuous (Gaussian) or discrete (categorical).
|
||||
"""
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
|
||||
if self.cfg.action_space == 'discrete':
|
||||
return self._discrete_pi(z, task)
|
||||
elif self.cfg.action_space == 'continuous':
|
||||
return self._continuous_pi(z, task)
|
||||
else:
|
||||
raise NotImplementedError(f"Action space {self.cfg.action} not supported.")
|
||||
|
||||
def Q(self, z, a, task, return_type='min', target=False, detach=False):
|
||||
"""
|
||||
Predict state-action value.
|
||||
|
||||
@@ -2,7 +2,7 @@ defaults:
|
||||
- override hydra/launcher: submitit_local
|
||||
|
||||
# environment
|
||||
task: dog-run
|
||||
task: discrete-cartpole-swingup
|
||||
obs: state
|
||||
|
||||
# evaluation
|
||||
@@ -79,6 +79,7 @@ task_title: ???
|
||||
multitask: ???
|
||||
tasks: ???
|
||||
obs_shape: ???
|
||||
action_space: ???
|
||||
action_dim: ???
|
||||
episode_length: ???
|
||||
obs_shapes: ???
|
||||
|
||||
@@ -3,6 +3,7 @@ import warnings
|
||||
|
||||
import gym
|
||||
|
||||
from envs.wrappers.discrete import DiscreteWrapper
|
||||
from envs.wrappers.multitask import MultitaskWrapper
|
||||
from envs.wrappers.pixels import PixelWrapper
|
||||
from envs.wrappers.tensor import TensorWrapper
|
||||
@@ -59,24 +60,34 @@ def make_env(cfg):
|
||||
gym.logger.set_level(40)
|
||||
if cfg.multitask:
|
||||
env = make_multitask_env(cfg)
|
||||
|
||||
else:
|
||||
env = None
|
||||
if cfg.task.startswith('discrete-'):
|
||||
discrete = True
|
||||
cfg.task = cfg.task.replace('discrete-', '')
|
||||
else:
|
||||
discrete = False
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||
try:
|
||||
env = fn(cfg)
|
||||
break
|
||||
except ValueError:
|
||||
pass
|
||||
if env is None:
|
||||
raise ValueError(f'Failed to make environment "{cfg.task}": please verify that dependencies are installed and that the task exists.')
|
||||
env = TensorWrapper(env)
|
||||
if discrete:
|
||||
env = DiscreteWrapper(env)
|
||||
if cfg.get('obs', 'state') == 'rgb':
|
||||
env = PixelWrapper(cfg, env)
|
||||
try: # Dict
|
||||
cfg.obs_shape = {k: v.shape for k, v in env.observation_space.spaces.items()}
|
||||
except: # Box
|
||||
cfg.obs_shape = {cfg.get('obs', 'state'): env.observation_space.shape}
|
||||
cfg.action_dim = env.action_space.shape[0]
|
||||
assert not isinstance(env.action_space, (gym.spaces.Dict, gym.spaces.MultiDiscrete)), \
|
||||
'Dict and MultiDiscrete action spaces are not supported.'
|
||||
cfg.action_space = 'discrete' if isinstance(env.action_space, gym.spaces.Discrete) else 'continuous'
|
||||
cfg.action_dim = env.action_space.n if cfg.action_space == 'discrete' else env.action_space.shape[0]
|
||||
cfg.episode_length = env.max_episode_steps
|
||||
cfg.seed_steps = max(1000, 5*cfg.episode_length)
|
||||
return env
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections import deque, defaultdict
|
||||
from collections import defaultdict
|
||||
from typing import Any, NamedTuple
|
||||
import dm_env
|
||||
import numpy as np
|
||||
|
||||
33
tdmpc2/envs/wrappers/discrete.py
Normal file
33
tdmpc2/envs/wrappers/discrete.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import gym
|
||||
import torch
|
||||
|
||||
from common import math
|
||||
|
||||
|
||||
class DiscreteWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for converting continuous action spaces to discrete via binning.
|
||||
"""
|
||||
|
||||
def __init__(self, env, bins_per_dim=5):
|
||||
super().__init__(env)
|
||||
self.bins_per_dim = bins_per_dim
|
||||
self.continuous_dims = self.env.action_space.shape[0]
|
||||
# Equally spaced bins along each dimension
|
||||
self.action_space = gym.spaces.Discrete(bins_per_dim ** self.continuous_dims)
|
||||
|
||||
def rand_act(self):
|
||||
action = torch.tensor(self.action_space.sample(), dtype=torch.int64)
|
||||
return math.int_to_one_hot(action, self.action_space.n)
|
||||
|
||||
def _discrete_to_continuous(self, action):
|
||||
# Convert a discrete action to a continuous action
|
||||
action = torch.argmax(action)
|
||||
action = action.item()
|
||||
action = [action // self.bins_per_dim ** i % self.bins_per_dim for i in range(self.continuous_dims)]
|
||||
action = torch.tensor(action, dtype=torch.float32)
|
||||
return (action - 1) / 1
|
||||
|
||||
def step(self, action):
|
||||
action = self._discrete_to_continuous(action)
|
||||
return self.env.step(action)
|
||||
137
tdmpc2/tdmpc2.py
137
tdmpc2/tdmpc2.py
@@ -103,11 +103,12 @@ class TDMPC2(torch.nn.Module):
|
||||
if task is not None:
|
||||
task = torch.tensor([task], device=self.device)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
||||
action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
||||
else:
|
||||
z = self.model.encode(obs, task)
|
||||
a = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||
return a.cpu()
|
||||
select_idx = int(not eval_mode or self.cfg.action_space == 'discrete')
|
||||
action = self.model.pi(z, task)[select_idx][0]
|
||||
return action.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def _estimate_value(self, z, actions, task):
|
||||
@@ -119,7 +120,37 @@ class TDMPC2(torch.nn.Module):
|
||||
G = G + discount * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
pi = self.model.pi(z, task)[1]
|
||||
if self.cfg.action_space == 'discrete':
|
||||
pi = pi.squeeze(1) # TODO: this is a bit hacky
|
||||
return G + discount * self.model.Q(z, pi, task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
|
||||
def _sample_policy(self, z, task):
|
||||
"""Sample trajectories from the policy prior."""
|
||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
for t in range(self.cfg.horizon-1):
|
||||
action = self.model.pi(z, task)[1]
|
||||
if self.cfg.action_space == 'discrete':
|
||||
action = action.squeeze(1)
|
||||
pi_actions[t] = action
|
||||
z = self.model.next(z, pi_actions[t], task)
|
||||
action = self.model.pi(z, task)[1]
|
||||
if self.cfg.action_space == 'discrete':
|
||||
action = action.squeeze(1)
|
||||
pi_actions[-1] = action
|
||||
return pi_actions
|
||||
|
||||
@torch.no_grad()
|
||||
def _sample_actions(self, n, mean=None, std=None):
|
||||
"""Sample actions from a Gaussian or Categorical distribution."""
|
||||
if self.cfg.action_space == 'discrete':
|
||||
actions = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, n), device=self.device)
|
||||
actions = math.int_to_one_hot(actions, self.cfg.action_dim)
|
||||
else:
|
||||
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
||||
actions = (mean.unsqueeze(1) + std.unsqueeze(1) * r).clamp(-1, 1)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||
@@ -135,61 +166,61 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
torch.Tensor: Action to take in the environment.
|
||||
"""
|
||||
# Sample policy trajectories
|
||||
# Encode observation
|
||||
z = self.model.encode(obs, task)
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||
for t in range(self.cfg.horizon-1):
|
||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
||||
_z = self.model.next(_z, pi_actions[t], task)
|
||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
||||
|
||||
# Initialize state and parameters
|
||||
z = z.repeat(self.cfg.num_samples, 1)
|
||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||
if not t0:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
|
||||
# Initialize parameters
|
||||
if self.cfg.action_space == 'continuous':
|
||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||
if not t0:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
else:
|
||||
mean, std = None, None
|
||||
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||
|
||||
# Sample policy trajectories
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||
actions[:, :self.cfg.num_pi_trajs] = self._sample_policy(z[:self.cfg.num_pi_trajs], task)
|
||||
|
||||
# Iterate MPPI
|
||||
for _ in range(self.cfg.iterations):
|
||||
|
||||
# Sample actions
|
||||
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
||||
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
|
||||
actions_sample = actions_sample.clamp(-1, 1)
|
||||
actions[:, self.cfg.num_pi_trajs:] = actions_sample
|
||||
# Sample random actions
|
||||
actions[:, self.cfg.num_pi_trajs:] = self._sample_actions(self.cfg.num_samples-self.cfg.num_pi_trajs, mean, std)
|
||||
if self.cfg.multitask:
|
||||
actions = actions * self.model._action_masks[task]
|
||||
|
||||
# Compute elite actions
|
||||
|
||||
# Select elites and compute scores
|
||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Update parameters
|
||||
max_value = elite_value.max(0).values
|
||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||
score = score / score.sum(0)
|
||||
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
||||
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||
if self.cfg.multitask:
|
||||
mean = mean * self.model._action_masks[task]
|
||||
std = std * self.model._action_masks[task]
|
||||
|
||||
# Update parameters
|
||||
if self.cfg.action_space == 'continuous':
|
||||
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
||||
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||
if self.cfg.multitask:
|
||||
mean = mean * self.model._action_masks[task]
|
||||
std = std * self.model._action_masks[task]
|
||||
else:
|
||||
break
|
||||
|
||||
# Select action
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||
a, std = actions[0], std[0]
|
||||
if not eval_mode:
|
||||
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||
self._prev_mean.copy_(mean)
|
||||
return a.clamp(-1, 1)
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
action = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)[0]
|
||||
if self.cfg.action_space == 'continuous':
|
||||
if not eval_mode:
|
||||
action = action + std[0] * torch.randn(self.cfg.action_dim, device=std.device)
|
||||
self._prev_mean.copy_(mean)
|
||||
action = action.clamp(-1, 1)
|
||||
|
||||
return action
|
||||
|
||||
def update_pi(self, zs, task):
|
||||
"""
|
||||
@@ -202,14 +233,28 @@ class TDMPC2(torch.nn.Module):
|
||||
Returns:
|
||||
float: Loss of the policy update.
|
||||
"""
|
||||
_, pis, log_pis, _ = self.model.pi(zs, task)
|
||||
qs = self.model.Q(zs, pis, task, return_type='avg', detach=True)
|
||||
self.scale.update(qs[0])
|
||||
_, actions, log_probs, action_probs = self.model.pi(zs, task)
|
||||
|
||||
if self.cfg.action_space == 'discrete':
|
||||
actions = torch.eye(self.cfg.action_dim, device=zs.device).unsqueeze(0)
|
||||
zs = zs.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
|
||||
actions = actions.unsqueeze(0).repeat(zs.shape[0], zs.shape[1], 1, 1)
|
||||
|
||||
qs = self.model.Q(zs, actions, task, return_type='avg', detach=True)
|
||||
|
||||
if self.cfg.action_space == 'discrete':
|
||||
qs = qs.squeeze(-1)
|
||||
self.scale.update(torch.sum(action_probs*qs,dim=(1,2),keepdim=True)[0])
|
||||
else:
|
||||
self.scale.update(qs[0])
|
||||
qs = self.scale(qs)
|
||||
|
||||
# Loss is a weighted sum of Q-values
|
||||
rho = torch.pow(self.cfg.rho, torch.arange(len(qs), device=self.device))
|
||||
pi_loss = ((self.cfg.entropy_coef * log_pis - qs).mean(dim=(1,2)) * rho).mean()
|
||||
if self.cfg.action_space == 'discrete':
|
||||
pi_loss = ((action_probs * (self.cfg.entropy_coef * log_probs - qs)).mean(dim=(1,2)) * rho).mean()
|
||||
else:
|
||||
pi_loss = ((self.cfg.entropy_coef * log_probs - qs).mean(dim=(1,2)) * rho).mean()
|
||||
pi_loss.backward()
|
||||
pi_grad_norm = torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm)
|
||||
self.pi_optim.step()
|
||||
@@ -231,6 +276,8 @@ class TDMPC2(torch.nn.Module):
|
||||
torch.Tensor: TD-target.
|
||||
"""
|
||||
pi = self.model.pi(next_z, task)[1]
|
||||
if self.cfg.action_space == 'discrete':
|
||||
pi = pi.squeeze(2) # TODO: this is a bit hacky
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||
|
||||
|
||||
@@ -54,7 +54,8 @@ class OnlineTrainer(Trainer):
|
||||
else:
|
||||
obs = obs.unsqueeze(0).cpu()
|
||||
if action is None:
|
||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||
action_val = -1 if self.cfg.action_space == 'discrete' else float('nan')
|
||||
action = torch.full_like(self.env.rand_act(), action_val)
|
||||
if reward is None:
|
||||
reward = torch.tensor(float('nan'))
|
||||
td = TensorDict(
|
||||
@@ -97,6 +98,11 @@ class OnlineTrainer(Trainer):
|
||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||
else:
|
||||
action = self.env.rand_act()
|
||||
if self.cfg.action_space == 'discrete':
|
||||
# exploration schedule
|
||||
# minimum 0.01, maximum 0.05, anneal over 20k steps
|
||||
if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):
|
||||
action = self.env.rand_act()
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._tds.append(self.to_td(obs, action, reward))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user