QoL improvements to termination signal debugging
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
name: episodic
|
||||
name: tdmpc2
|
||||
channels:
|
||||
- pytorch-nightly
|
||||
- nvidia
|
||||
@@ -55,3 +55,7 @@ dependencies:
|
||||
# MyoSuite:
|
||||
# - myosuite
|
||||
####################
|
||||
# Classic MuJoCo/Box2d:
|
||||
# - swig
|
||||
# - gymnasium[box2d]
|
||||
####################
|
||||
|
||||
@@ -25,7 +25,7 @@ class WorldModel(nn.Module):
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||
self._termination = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
@@ -54,8 +54,8 @@ class WorldModel(nn.Module):
|
||||
|
||||
def __repr__(self):
|
||||
repr = 'TD-MPC2 World Model\n'
|
||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Policy prior', 'Q-functions']
|
||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._pi, self._Qs]):
|
||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
||||
repr += f"{modules[i]}: {m}\n"
|
||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||
return repr
|
||||
@@ -127,14 +127,14 @@ class WorldModel(nn.Module):
|
||||
z = torch.cat([z, a], dim=-1)
|
||||
return self._reward(z)
|
||||
|
||||
def terminated(self, z, task):
|
||||
def termination(self, z, task):
|
||||
"""
|
||||
Predicts termination signal.
|
||||
"""
|
||||
assert task is None
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
return torch.sigmoid(self._terminated(z))
|
||||
return torch.sigmoid(self._termination(z))
|
||||
|
||||
def pi(self, z, task):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
terminated_coef: 0.1
|
||||
termination_coef: 1
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
lr: 3e-4
|
||||
@@ -90,4 +90,4 @@ seed_steps: ???
|
||||
bin_size: ???
|
||||
|
||||
# speedups
|
||||
compile: False
|
||||
compile: false
|
||||
|
||||
@@ -9,10 +9,10 @@ from envs.wrappers.tensor import TensorWrapper
|
||||
def missing_dependencies(task):
|
||||
raise ValueError(f'Missing dependencies for task {task}; install dependencies to use this environment.')
|
||||
|
||||
# try:
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
# except:
|
||||
# make_dm_control_env = missing_dependencies
|
||||
try:
|
||||
from envs.dmcontrol import make_env as make_dm_control_env
|
||||
except:
|
||||
make_dm_control_env = missing_dependencies
|
||||
try:
|
||||
from envs.maniskill import make_env as make_maniskill_env
|
||||
except:
|
||||
@@ -25,6 +25,10 @@ try:
|
||||
from envs.myosuite import make_env as make_myosuite_env
|
||||
except:
|
||||
make_myosuite_env = missing_dependencies
|
||||
try:
|
||||
from envs.mujoco import make_env as make_mujoco_env
|
||||
except:
|
||||
make_mujoco_env = missing_dependencies
|
||||
|
||||
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
@@ -61,7 +65,7 @@ def make_env(cfg):
|
||||
|
||||
else:
|
||||
env = None
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env]:
|
||||
for fn in [make_dm_control_env, make_maniskill_env, make_metaworld_env, make_myosuite_env, make_mujoco_env]:
|
||||
try:
|
||||
env = fn(cfg)
|
||||
except ValueError:
|
||||
|
||||
52
tdmpc2/envs/mujoco.py
Normal file
52
tdmpc2/envs/mujoco.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
import gymnasium as gym
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
|
||||
MUJOCO_TASKS = {
|
||||
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
||||
'lunarlander-continuous': 'LunarLander-v2',
|
||||
}
|
||||
|
||||
class MuJoCoWrapper(gym.Wrapper):
|
||||
def __init__(self, env, cfg):
|
||||
super().__init__(env)
|
||||
self.env = env
|
||||
self.cfg = cfg
|
||||
self._cumulative_reward = 0
|
||||
|
||||
def reset(self):
|
||||
self._cumulative_reward = 0
|
||||
return self.env.reset()[0]
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action.copy())
|
||||
self._cumulative_reward += reward
|
||||
done = terminated or truncated
|
||||
info['terminated'] = terminated
|
||||
if self.cfg.task == 'lunarlander-continuous':
|
||||
info['success'] = self._cumulative_reward > 200
|
||||
return obs, reward, done, info
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
return self.env.unwrapped
|
||||
|
||||
def render(self, **kwargs):
|
||||
return self.env.render(**kwargs)
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
"""
|
||||
Make classic/MuJoCo environment.
|
||||
"""
|
||||
if not cfg.task in MUJOCO_TASKS:
|
||||
raise ValueError('Unknown task:', cfg.task)
|
||||
assert cfg.obs == 'state', 'This task only supports state observations.'
|
||||
if cfg.task == 'lunarlander-continuous':
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
||||
else:
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||
env = MuJoCoWrapper(env, cfg)
|
||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
||||
return env
|
||||
@@ -24,7 +24,7 @@ class TDMPC2(torch.nn.Module):
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
{'params': self.model._reward.parameters()},
|
||||
{'params': self.model._terminated.parameters()},
|
||||
{'params': self.model._termination.parameters()},
|
||||
{'params': self.model._Qs.parameters()},
|
||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []
|
||||
}
|
||||
@@ -36,6 +36,8 @@ class TDMPC2(torch.nn.Module):
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
print('Episode length:', cfg.episode_length)
|
||||
print('Discount factor:', self.discount)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
print('Compiling update function with torch.compile...')
|
||||
@@ -122,17 +124,17 @@ class TDMPC2(torch.nn.Module):
|
||||
def _estimate_value(self, z, actions, task):
|
||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||
G, discount = 0, 1
|
||||
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||
termination = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||
for t in range(self.cfg.horizon):
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||
z = self.model.next(z, actions[t], task)
|
||||
|
||||
G = G + discount * (1-terminated) * reward
|
||||
G = G + discount * (1-termination) * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
terminated = torch.clip(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
|
||||
termination = torch.clip(termination + (self.model.termination(z, task) > 0.5).float(), max=1.)
|
||||
action, _ = self.model.pi(z, task)
|
||||
return G + discount * (1-terminated) * self.model.Q(z, action, task, return_type='avg')
|
||||
return G + discount * (1-termination) * self.model.Q(z, action, task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
|
||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||
@@ -278,7 +280,7 @@ class TDMPC2(torch.nn.Module):
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
terminated_pred = self.model.terminated(zs[-1], task)
|
||||
termination_pred = self.model.termination(zs[-1], task)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, value_loss = 0, 0
|
||||
@@ -289,12 +291,12 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
consistency_loss = consistency_loss / self.cfg.horizon
|
||||
reward_loss = reward_loss / self.cfg.horizon
|
||||
terminated_loss = F.binary_cross_entropy(terminated_pred, terminated[-1])
|
||||
termination_loss = F.binary_cross_entropy(termination_pred, terminated[-1])
|
||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
self.cfg.reward_coef * reward_loss +
|
||||
self.cfg.terminated_coef * terminated_loss +
|
||||
self.cfg.termination_coef * termination_loss +
|
||||
self.cfg.value_coef * value_loss
|
||||
)
|
||||
|
||||
@@ -316,6 +318,9 @@ class TDMPC2(torch.nn.Module):
|
||||
"consistency_loss": consistency_loss,
|
||||
"reward_loss": reward_loss,
|
||||
"value_loss": value_loss,
|
||||
"termination_loss": termination_loss,
|
||||
"termination_mean": termination_pred.mean(),
|
||||
"termination_mean_gt": terminated[-1].mean(),
|
||||
"total_loss": total_loss,
|
||||
"grad_norm": grad_norm,
|
||||
})
|
||||
|
||||
@@ -48,9 +48,8 @@ def train(cfg: dict):
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
|
||||
assert cfg.task == 'cartpole-balance-sparse' and cfg.episodic, \
|
||||
f'This branch is experimental and only supports cartpole-balance-sparse at this time.'
|
||||
assert cfg.episodic, \
|
||||
f'This branch is experimental and only supports episodic RL tasks at this time.'
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
trainer = trainer_cls(
|
||||
|
||||
@@ -87,6 +87,8 @@ class OnlineTrainer(Trainer):
|
||||
train_metrics.update(
|
||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||
episode_success=info['success'],
|
||||
episode_length=len(self._tds),
|
||||
episode_terminated=info['terminated'],
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
|
||||
Reference in New Issue
Block a user