QoL improvements to termination signal debugging

This commit is contained in:
Nicklas Hansen
2025-04-08 19:15:31 -07:00
parent add30b5a74
commit 81eb17068e
8 changed files with 90 additions and 24 deletions

View File

@@ -1,4 +1,4 @@
name: episodic
name: tdmpc2
channels:
- pytorch-nightly
- nvidia
@@ -55,3 +55,7 @@ dependencies:
# MyoSuite:
# - myosuite
####################
# Classic MuJoCo/Box2d:
# - swig
# - gymnasium[box2d]
####################

View File

@@ -25,7 +25,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init)
@@ -54,8 +54,8 @@ class WorldModel(nn.Module):
def __repr__(self):
repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params)
return repr
@@ -127,14 +127,14 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def terminated(self, z, task):
def termination(self, z, task):
"""
Predicts termination signal.
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
return torch.sigmoid(self._terminated(z))
return torch.sigmoid(self._termination(z))
def pi(self, z, task):
"""

View File

@@ -16,7 +16,7 @@ steps: 10_000_000
batch_size: 256
reward_coef: 0.1
value_coef: 0.1
terminated_coef: 0.1
termination_coef: 1
consistency_coef: 20
rho: 0.5
lr: 3e-4
@@ -90,4 +90,4 @@ seed_steps: ???
bin_size: ???
# speedups
compile: False
compile: false

View File

@@ -9,10 +9,10 @@ from envs.wrappers.tensor import TensorWrapper
def missing_dependencies(task):
    """
    Placeholder that stands in for an environment factory whose import failed.

    Always raises a ValueError naming the given `task`; it never returns.
    """
    message = f'Missing dependencies for task {task}; install dependencies to use this environment.'
    raise ValueError(message)
# try:
from envs.dmcontrol import make_env as make_dm_control_env
# except:
# make_dm_control_env = missing_dependencies
try:
from envs.dmcontrol import make_env as make_dm_control_env
except:
make_dm_control_env = missing_dependencies
try:
from envs.maniskill import make_env as make_maniskill_env
except:
@@ -25,6 +25,10 @@ try:
from envs.myosuite import make_env as make_myosuite_env
except:
make_myosuite_env = missing_dependencies
try:
from envs.mujoco import make_env as make_mujoco_env
except:
make_mujoco_env = missing_dependencies
warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -61,7 +65,7 @@ def make_env(cfg):
else:
env = None
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
try:
env = fn(cfg)
except ValueError:

52
tdmpc2/envs/mujoco.py Normal file
View File

@@ -0,0 +1,52 @@
import numpy as np
import gymnasium as gym
from envs.wrappers.timeout import Timeout
# Task names accepted by make_env (cfg.task) mapped to the environment id
# that is handed to gym.make.
MUJOCO_TASKS = {
    'mujoco-halfcheetah': 'HalfCheetah-v4',
    'lunarlander-continuous': 'LunarLander-v2',
}
class MuJoCoWrapper(gym.Wrapper):
    """
    Adapter around a gymnasium environment.

    - reset() returns only the first element of the wrapped env's reset result.
    - step() collapses the wrapped env's `terminated` / `truncated` flags into
      a single `done` flag and returns a 4-tuple; the raw `terminated` flag is
      preserved in `info['terminated']`.
    - The reward accumulated since the last reset is tracked so that
      `info['success']` can be reported for the 'lunarlander-continuous' task.
    """

    def __init__(self, env, cfg):
        """
        Args:
            env: environment to wrap.
            cfg: config object; only `cfg.task` is read by this wrapper.
        """
        super().__init__(env)
        self.env = env
        self.cfg = cfg
        self._cumulative_reward = 0

    def reset(self, **kwargs):
        """
        Start a new episode and return the initial observation.

        Any keyword arguments (e.g. `seed`, `options`) are forwarded unchanged
        to the wrapped environment's reset(); calling with no arguments
        behaves exactly as before.
        """
        self._cumulative_reward = 0
        return self.env.reset(**kwargs)[0]

    def step(self, action):
        """
        Advance the environment by one step.

        Returns:
            (obs, reward, done, info) where `done` is `terminated or truncated`
            and `info['terminated']` holds the wrapped env's `terminated` flag.
        """
        # The action is copied before being handed to the wrapped env.
        obs, reward, terminated, truncated, info = self.env.step(action.copy())
        self._cumulative_reward += reward
        info['terminated'] = terminated
        if self.cfg.task == 'lunarlander-continuous':
            # Success is reported once the episode's summed reward exceeds 200.
            info['success'] = self._cumulative_reward > 200
        return obs, reward, terminated or truncated, info

    @property
    def unwrapped(self):
        """The wrapped environment's `unwrapped` attribute."""
        return self.env.unwrapped

    def render(self, **kwargs):
        """Delegate rendering to the wrapped environment, forwarding kwargs."""
        return self.env.render(**kwargs)
def make_env(cfg):
    """
    Make classic/MuJoCo environment.

    Args:
        cfg: config object; `cfg.task` selects the environment and `cfg.obs`
            must be 'state'.

    Returns:
        The gym environment wrapped in MuJoCoWrapper and then Timeout
        (500 steps for tasks starting with 'lunarlander', 1000 otherwise).

    Raises:
        ValueError: if `cfg.task` is not a key of MUJOCO_TASKS.
        AssertionError: if `cfg.obs` is not 'state'.
    """
    if cfg.task not in MUJOCO_TASKS:
        # Single formatted message (previously two positional args, which
        # made str(exc) render as a tuple). Exception type is unchanged.
        raise ValueError(f'Unknown task: {cfg.task}')
    assert cfg.obs == 'state', 'This task only supports state observations.'
    make_kwargs = {'render_mode': 'rgb_array'}
    if cfg.task == 'lunarlander-continuous':
        # This task is created with continuous=True.
        make_kwargs['continuous'] = True
    env = gym.make(MUJOCO_TASKS[cfg.task], **make_kwargs)
    env = MuJoCoWrapper(env, cfg)
    max_steps = 500 if cfg.task.startswith('lunarlander') else 1000
    return Timeout(env, max_episode_steps=max_steps)

View File

@@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module):
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()},
{'params': self.model._terminated.parameters()},
{'params': self.model._termination.parameters()},
{'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
}
@@ -36,6 +36,8 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
print('Episode length:', cfg.episode_length)
print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
print('Compiling update function with torch.compile...')
@@ -122,17 +124,17 @@ class TDMPC2(torch.nn.Module):
def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G = G + discount * (1-terminated) * reward
G = G + discount * (1-termination) * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
terminated = torch.clip(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
action, _ = self.model.pi(z, task)
return G + discount * (1-terminated) * self.model.Q(z, action, task, return_type='avg')
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
@torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None):
@@ -278,7 +280,7 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
terminated_pred = self.model.terminated(zs[-1], task)
termination_pred = self.model.termination(zs[-1], task)
# Compute losses
reward_loss, value_loss = 0, 0
@@ -289,12 +291,12 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated[-1])
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1])
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.terminated_coef * terminated_loss +
self.cfg.termination_coef * termination_loss +
self.cfg.value_coef * value_loss
)
@@ -316,6 +318,9 @@ class TDMPC2(torch.nn.Module):
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"value_loss": value_loss,
"termination_loss": termination_loss,
"termination_mean": termination_pred.mean(),
"termination_mean_gt": terminated[-1].mean(),
"total_loss": total_loss,
"grad_norm": grad_norm,
})

View File

@@ -48,9 +48,8 @@ def train(cfg: dict):
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
assert cfg.episodic, \
f'This branch is experimental and only supports episodic RL tasks at this time.'
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls(

View File

@@ -87,6 +87,8 @@ class OnlineTrainer(Trainer):
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
episode_length=len(self._tds),
episode_terminated=info['terminated'],
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')