QoL improvements to termination signal debugging
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
name: episodic
|
name: tdmpc2
|
||||||
channels:
|
channels:
|
||||||
- pytorch-nightly
|
- pytorch-nightly
|
||||||
- nvidia
|
- nvidia
|
||||||
@@ -55,3 +55,7 @@ dependencies:
|
|||||||
# MyoSuite:
|
# MyoSuite:
|
||||||
# - myosuite
|
# - myosuite
|
||||||
####################
|
####################
|
||||||
|
# Classic MuJoCo/Box2d:
|
||||||
|
# - swig
|
||||||
|
# - gymnasium[box2d]
|
||||||
|
####################
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class WorldModel(nn.Module):
|
|||||||
self._encoder = layers.enc(cfg)
|
self._encoder = layers.enc(cfg)
|
||||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||||
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
@@ -54,8 +54,8 @@ class WorldModel(nn.Module):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
repr = 'TD-MPC2 World Model\n'
|
repr = 'TD-MPC2 World Model\n'
|
||||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
|
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
||||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
|
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
||||||
repr += f"{modules[i]}: {m}\n"
|
repr += f"{modules[i]}: {m}\n"
|
||||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||||
return repr
|
return repr
|
||||||
@@ -127,14 +127,14 @@ class WorldModel(nn.Module):
|
|||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._reward(z)
|
return self._reward(z)
|
||||||
|
|
||||||
def terminated(self, z, task):
|
def termination(self, z, task):
|
||||||
"""
|
"""
|
||||||
Predicts termination signal.
|
Predicts termination signal.
|
||||||
"""
|
"""
|
||||||
assert task is None
|
assert task is None
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
return torch.sigmoid(self._terminated(z))
|
return torch.sigmoid(self._termination(z))
|
||||||
|
|
||||||
def pi(self, z, task):
|
def pi(self, z, task):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ steps: 10_000_000
|
|||||||
batch_size: 256
|
batch_size: 256
|
||||||
reward_coef: 0.1
|
reward_coef: 0.1
|
||||||
value_coef: 0.1
|
value_coef: 0.1
|
||||||
terminated_coef: 0.1
|
termination_coef: 1
|
||||||
consistency_coef: 20
|
consistency_coef: 20
|
||||||
rho: 0.5
|
rho: 0.5
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
@@ -90,4 +90,4 @@ seed_steps: ???
|
|||||||
bin_size: ???
|
bin_size: ???
|
||||||
|
|
||||||
# speedups
|
# speedups
|
||||||
compile: False
|
compile: false
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ from envs.wrappers.tensor import TensorWrapper
|
|||||||
def missing_dependencies(task):
|
def missing_dependencies(task):
|
||||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||||
|
|
||||||
# try:
|
try:
|
||||||
from envs.dmcontrol import make_env as make_dm_control_env
|
from envs.dmcontrol import make_env as make_dm_control_env
|
||||||
# except:
|
except:
|
||||||
# make_dm_control_env = missing_dependencies
|
make_dm_control_env = missing_dependencies
|
||||||
try:
|
try:
|
||||||
from envs.maniskill import make_env as make_maniskill_env
|
from envs.maniskill import make_env as make_maniskill_env
|
||||||
except:
|
except:
|
||||||
@@ -25,6 +25,10 @@ try:
|
|||||||
from envs.myosuite import make_env as make_myosuite_env
|
from envs.myosuite import make_env as make_myosuite_env
|
||||||
except:
|
except:
|
||||||
make_myosuite_env = missing_dependencies
|
make_myosuite_env = missing_dependencies
|
||||||
|
# Optional dependency: classic/MuJoCo tasks. If the import fails for any
# ordinary reason, fall back to `missing_dependencies`, which raises a
# ValueError when it is eventually called for a task.
try:
    from envs.mujoco import make_env as make_mujoco_env
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit raised during import are not silently swallowed.
    make_mujoco_env = missing_dependencies
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
@@ -61,7 +65,7 @@ def make_env(cfg):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
env = None
|
env = None
|
||||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
|
||||||
try:
|
try:
|
||||||
env = fn(cfg)
|
env = fn(cfg)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
52
tdmpc2/envs/mujoco.py
Normal file
52
tdmpc2/envs/mujoco.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import numpy as np
|
||||||
|
import gymnasium as gym
|
||||||
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
|
# Maps the task names accepted in `cfg.task` to the environment IDs that are
# passed to `gym.make`.
# NOTE(review): despite the constant's name, LunarLander is presumably a
# Box2D task rather than a MuJoCo one -- confirm and consider renaming.
MUJOCO_TASKS = {
    'mujoco-halfcheetah': 'HalfCheetah-v4',
    'lunarlander-continuous': 'LunarLander-v2',
}
|
||||||
|
|
||||||
|
class MuJoCoWrapper(gym.Wrapper):
    """
    Adapter that exposes a wrapped environment through a 4-tuple step API.

    The wrapped environment's `step` is expected to return
    `(obs, reward, terminated, truncated, info)`; this wrapper collapses the
    two end-of-episode flags into a single `done` value and records the raw
    `terminated` flag in `info['terminated']`. It also accumulates the
    episode reward so that a success flag can be derived for the
    `lunarlander-continuous` task.
    """

    def __init__(self, env, cfg):
        """Store the wrapped environment and config; start the reward tally at 0."""
        super().__init__(env)
        self.env = env
        self.cfg = cfg
        self._cumulative_reward = 0

    def reset(self):
        """Reset the episode reward tally and return only the first element of the wrapped reset result."""
        self._cumulative_reward = 0
        reset_result = self.env.reset()
        return reset_result[0]

    def step(self, action):
        """
        Step the wrapped environment with a copy of `action`.

        Returns `(obs, reward, done, info)` where `done` is
        `terminated or truncated`.
        """
        # The action is copied before being handed to the wrapped environment.
        transition = self.env.step(action.copy())
        observation, reward, terminated, truncated, info = transition
        self._cumulative_reward += reward
        info['terminated'] = terminated
        if self.cfg.task == 'lunarlander-continuous':
            # Success is defined here as an episode return above 200.
            info['success'] = self._cumulative_reward > 200
        episode_over = terminated or truncated
        return observation, reward, episode_over, info

    @property
    def unwrapped(self):
        """Return the wrapped environment's own `unwrapped` attribute."""
        return self.env.unwrapped

    def render(self, **kwargs):
        """Forward rendering, including any keyword arguments, to the wrapped environment."""
        return self.env.render(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def make_env(cfg):
    """
    Make classic/MuJoCo environment.

    Args:
        cfg: Config object; `cfg.task` selects the environment and `cfg.obs`
            must be `'state'`.

    Returns:
        The environment wrapped in `MuJoCoWrapper` and then `Timeout`.

    Raises:
        ValueError: If `cfg.task` is not a key of `MUJOCO_TASKS`.
        AssertionError: If `cfg.obs` is not `'state'`.
    """
    if cfg.task not in MUJOCO_TASKS:
        # Single formatted message; passing two positional arguments to
        # ValueError would render as a tuple instead of readable text.
        raise ValueError(f'Unknown task: {cfg.task}')
    # Kept as an assert so the exception type raised for an unsupported
    # observation mode stays the same as before.
    assert cfg.obs == 'state', 'This task only supports state observations.'

    make_kwargs = {'render_mode': 'rgb_array'}
    if cfg.task == 'lunarlander-continuous':
        # LunarLander is created with its continuous-action option enabled.
        make_kwargs['continuous'] = True
    env = gym.make(MUJOCO_TASKS[cfg.task], **make_kwargs)

    env = MuJoCoWrapper(env, cfg)
    # Episodes are capped at 500 steps for lunarlander tasks, 1000 otherwise.
    max_episode_steps = 500 if cfg.task.startswith('lunarlander') else 1000
    env = Timeout(env, max_episode_steps=max_episode_steps)
    return env
|
||||||
@@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||||
{'params': self.model._dynamics.parameters()},
|
{'params': self.model._dynamics.parameters()},
|
||||||
{'params': self.model._reward.parameters()},
|
{'params': self.model._reward.parameters()},
|
||||||
{'params': self.model._terminated.parameters()},
|
{'params': self.model._termination.parameters()},
|
||||||
{'params': self.model._Qs.parameters()},
|
{'params': self.model._Qs.parameters()},
|
||||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||||
}
|
}
|
||||||
@@ -36,6 +36,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
|
print('Episode length:', cfg.episode_length)
|
||||||
|
print('Discount factor:', self.discount)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
print('Compiling update function with torch.compile...')
|
print('Compiling update function with torch.compile...')
|
||||||
@@ -122,17 +124,17 @@ class TDMPC2(torch.nn.Module):
|
|||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||||
G, discount = 0, 1
|
G, discount = 0, 1
|
||||||
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[t], task)
|
||||||
|
|
||||||
G = G + discount * (1-terminated) * reward
|
G = G + discount * (1-termination) * reward
|
||||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
discount = discount * discount_update
|
discount = discount * discount_update
|
||||||
terminated = torch.clip(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
|
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
|
||||||
action, _ = self.model.pi(z, task)
|
action, _ = self.model.pi(z, task)
|
||||||
return G + discount * (1-terminated) * self.model.Q(z, action, task, return_type='avg')
|
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
@@ -278,7 +280,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
terminated_pred = self.model.terminated(zs[-1], task)
|
termination_pred = self.model.termination(zs[-1], task)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, value_loss = 0, 0
|
reward_loss, value_loss = 0, 0
|
||||||
@@ -289,12 +291,12 @@ class TDMPC2(torch.nn.Module):
|
|||||||
|
|
||||||
consistency_loss = consistency_loss / self.cfg.horizon
|
consistency_loss = consistency_loss / self.cfg.horizon
|
||||||
reward_loss = reward_loss / self.cfg.horizon
|
reward_loss = reward_loss / self.cfg.horizon
|
||||||
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated[-1])
|
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1])
|
||||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
self.cfg.reward_coef * reward_loss +
|
self.cfg.reward_coef * reward_loss +
|
||||||
self.cfg.terminated_coef * terminated_loss +
|
self.cfg.termination_coef * termination_loss +
|
||||||
self.cfg.value_coef * value_loss
|
self.cfg.value_coef * value_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,6 +318,9 @@ class TDMPC2(torch.nn.Module):
|
|||||||
"consistency_loss": consistency_loss,
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": reward_loss,
|
"reward_loss": reward_loss,
|
||||||
"value_loss": value_loss,
|
"value_loss": value_loss,
|
||||||
|
"termination_loss": termination_loss,
|
||||||
|
"termination_mean": termination_pred.mean(),
|
||||||
|
"termination_mean_gt": terminated[-1].mean(),
|
||||||
"total_loss": total_loss,
|
"total_loss": total_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -48,9 +48,8 @@ def train(cfg: dict):
|
|||||||
cfg = parse_cfg(cfg)
|
cfg = parse_cfg(cfg)
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
|
assert cfg.episodic, \
|
||||||
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
|
f'This branch is experimental and only supports episodic RL tasks at this time.'
|
||||||
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
|
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
|
|||||||
@@ -87,6 +87,8 @@ class OnlineTrainer(Trainer):
|
|||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
|
episode_length=len(self._tds),
|
||||||
|
episode_terminated=info['terminated'],
|
||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
|
|||||||
Reference in New Issue
Block a user