full support for episodic rl

This commit is contained in:
Nicklas Hansen
2025-04-15 15:55:05 -07:00
parent 38f853efc4
commit eece80123d
11 changed files with 55 additions and 74 deletions

View File

@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
("step", "I", "int"), ("step", "I", "int"),
("episode_reward", "R", "float"), ("episode_reward", "R", "float"),
("episode_success", "S", "float"), ("episode_success", "S", "float"),
("total_time", "T", "time"), ("elapsed_time", "T", "time"),
] ]
CAT_TO_COLOR = { CAT_TO_COLOR = {

View File

@@ -1,5 +1,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from tensordict import TensorDict
def soft_ce(pred, target, cfg): def soft_ce(pred, target, cfg):
@@ -84,11 +85,26 @@ def two_hot_inv(x, cfg):
def gumbel_softmax_sample(p, temperature=1.0, dim=0): def gumbel_softmax_sample(p, temperature=1.0, dim=0):
"""Sample from the Gumbel-Softmax distribution."""
logits = p.log() logits = p.log()
# Generate Gumbel noise
gumbels = ( gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
) # ~Gumbel(0,1) ) # ~Gumbel(0,1)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau) gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim) y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1) return y_soft.argmax(-1)
def termination_statistics(pred, target, eps=1e-9):
"""Compute episode termination statistics."""
pred = pred.squeeze(-1)
target = target.squeeze(-1)
rate = target.sum() / len(target)
tp = ((pred > 0.5) & (target == 1)).sum()
fn = ((pred <= 0.5) & (target == 1)).sum()
fp = ((pred > 0.5) & (target == 0)).sum()
recall = tp / (tp + fn + eps)
precision = tp / (tp + fp + eps)
f1 = 2 * (precision * recall) / (precision + recall + eps)
return TensorDict({'termination_rate': rate,
'termination_f1': f1})

View File

@@ -56,6 +56,8 @@ class WorldModel(nn.Module):
repr = 'TD-MPC2 World Model\n' repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions'] modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]): for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
if m == self._termination and not self.cfg.episodic:
continue
repr += f"{modules[i]}: {m}\n" repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params) repr += "Learnable parameters: {:,}".format(self.total_params)
return repr return repr
@@ -127,16 +129,17 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def termination(self, z, task, sigmoid=True): def termination(self, z, task, unnormalized=False):
""" """
Predicts termination signal. Predicts termination signal.
""" """
assert task is None assert task is None
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
if sigmoid: if unnormalized:
return torch.sigmoid(self._termination(z))
return self._termination(z) return self._termination(z)
return torch.sigmoid(self._termination(z))
def pi(self, z, task): def pi(self, z, task):
""" """
@@ -186,12 +189,10 @@ class WorldModel(nn.Module):
`return_type` can be one of [`min`, `avg`, `all`]: `return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values. - `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values. - `avg`: return the average of two randomly subsampled Q-values.
- 'min-all': return the minimum of all Q-values.
- 'avg-all': return the average of all Q-values.
- `all`: return all Q-values. - `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not. `target` specifies whether to use the target Q-networks or not.
""" """
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'} assert return_type in {'min', 'avg', 'all'}
if self.cfg.multitask: if self.cfg.multitask:
z = self.task_emb(z, task) z = self.task_emb(z, task)
@@ -208,14 +209,6 @@ class WorldModel(nn.Module):
if return_type == 'all': if return_type == 'all':
return out return out
if return_type == 'avg-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.mean(0)
if return_type == 'min-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.min(0).values
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2] qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q = math.two_hot_inv(out[qidx], self.cfg) Q = math.two_hot_inv(out[qidx], self.cfg)
if return_type == "min": if return_type == "min":

View File

@@ -2,9 +2,9 @@ defaults:
- override hydra/launcher: submitit_local - override hydra/launcher: submitit_local
# environment # environment
task: cartpole-balance-sparse task: dog-run
obs: state obs: state
episodic: true episodic: false
# evaluation # evaluation
checkpoint: ??? checkpoint: ???

View File

@@ -9,9 +9,8 @@ from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom') suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS) suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale from dm_control.suite.wrappers import action_scale
from envs.wrappers.timeout import Timeout
from envs.wrappers.episodic import EpisodicWrapper from envs.wrappers.timeout import Timeout
def get_obs_shape(env): def get_obs_shape(env):

View File

@@ -6,6 +6,7 @@ from envs.wrappers.timeout import Timeout
MUJOCO_TASKS = { MUJOCO_TASKS = {
'mujoco-walker': 'Walker2d-v4', 'mujoco-walker': 'Walker2d-v4',
'mujoco-halfcheetah': 'HalfCheetah-v4', 'mujoco-halfcheetah': 'HalfCheetah-v4',
'bipedal-walker': 'BipedalWalker-v3',
'lunarlander-continuous': 'LunarLander-v2', 'lunarlander-continuous': 'LunarLander-v2',
} }
@@ -49,7 +50,10 @@ def make_env(cfg):
else: else:
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
env = MuJoCoWrapper(env, cfg) env = MuJoCoWrapper(env, cfg)
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000) env = Timeout(env, max_episode_steps={
'lunarlander-continuous': 500,
'bipedal-walker': 1600,
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
return env return env

View File

@@ -1,24 +0,0 @@
from collections import deque
import gymnasium as gym
import numpy as np
import torch
class EpisodicWrapper(gym.Wrapper):
"""
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
"""
def __init__(self, cfg, env):
super().__init__(env)
assert cfg.task == 'cartpole-balance-sparse'
self.cfg = cfg
self.env = env
def step(self, action):
obs, reward, done, info = self.env.step(action)
if self.cfg.episodic and reward == 0:
done = True
info['terminated'] = True
return obs, reward, done, info

View File

@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor( self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length) ) if self.cfg.multitask else self._get_discount(cfg.episode_length)
print('Max episode length:', cfg.episode_length) print('Episode length:', cfg.episode_length)
print('Discount factor:', self.discount) print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile: if cfg.compile:
@@ -197,7 +197,7 @@ class TDMPC2(torch.nn.Module):
std = std * self.model._action_masks[task] std = std * self.model._action_masks[task]
# Select action # Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1) actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
a, std = actions[0], std[0] a, std = actions[0], std[0]
if not eval_mode: if not eval_mode:
@@ -279,7 +279,7 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
termination_pred = self.model.termination(zs[1:], task, sigmoid=False) termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
# Compute losses # Compute losses
reward_loss, value_loss = 0, 0 reward_loss, value_loss = 0, 0
@@ -290,7 +290,10 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon
if self.cfg.episodic:
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated) termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
else:
termination_loss = 0.
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss + self.cfg.consistency_coef * consistency_loss +
@@ -313,30 +316,16 @@ class TDMPC2(torch.nn.Module):
# Return training statistics # Return training statistics
self.model.eval() self.model.eval()
# termination classification metrics
# number of terminations in batch
termination_rate = terminated[-1].sum() / self.cfg.batch_size
# recall = TP / (TP + FN)
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
# precision = TP / (TP + FP)
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
# F1 score = 2 * (precision * recall) / (precision + recall)
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
info = TensorDict({ info = TensorDict({
"consistency_loss": consistency_loss, "consistency_loss": consistency_loss,
"reward_loss": reward_loss, "reward_loss": reward_loss,
"value_loss": value_loss, "value_loss": value_loss,
"termination_loss": termination_loss, "termination_loss": termination_loss,
"termination_rate": termination_rate,
"termination_recall": termination_recall,
"termination_precision": termination_precision,
"termination_f1": termination_f1,
"total_loss": total_loss, "total_loss": total_loss,
"grad_norm": grad_norm, "grad_norm": grad_norm,
}) })
if self.cfg.episodic:
info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
info.update(pi_info) info.update(pi_info)
return info.detach().mean() return info.detach().mean()

View File

@@ -48,8 +48,6 @@ def train(cfg: dict):
cfg = parse_cfg(cfg) cfg = parse_cfg(cfg)
set_seed(cfg.seed) set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir) print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
assert cfg.episodic, \
f'This branch is experimental and only supports episodic RL tasks at this time.'
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls( trainer = trainer_cls(

View File

@@ -81,7 +81,7 @@ class OfflineTrainer(Trainer):
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0: if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
metrics = { metrics = {
'iteration': i, 'iteration': i,
'total_time': time() - self._start_time, 'elapsed_time': time() - self._start_time,
} }
metrics.update(train_metrics) metrics.update(train_metrics)
if i % self.cfg.eval_freq == 0: if i % self.cfg.eval_freq == 0:

View File

@@ -17,15 +17,17 @@ class OnlineTrainer(Trainer):
def common_metrics(self): def common_metrics(self):
"""Return a dictionary of current metrics.""" """Return a dictionary of current metrics."""
elapsed_time = time() - self._start_time
return dict( return dict(
step=self._step, step=self._step,
episode=self._ep_idx, episode=self._ep_idx,
total_time=time() - self._start_time, elapsed_time=elapsed_time,
steps_per_second=self._step / elapsed_time
) )
def eval(self): def eval(self):
"""Evaluate a TD-MPC2 agent.""" """Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], [] ep_rewards, ep_successes, ep_lengths = [], [], []
for i in range(self.cfg.eval_episodes): for i in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0 obs, done, ep_reward, t = self.env.reset(), False, 0, 0
if self.cfg.save_video: if self.cfg.save_video:
@@ -40,11 +42,13 @@ class OnlineTrainer(Trainer):
self.logger.video.record(self.env) self.logger.video.record(self.env)
ep_rewards.append(ep_reward) ep_rewards.append(ep_reward)
ep_successes.append(info['success']) ep_successes.append(info['success'])
ep_lengths.append(t)
if self.cfg.save_video: if self.cfg.save_video:
self.logger.video.save(self._step) self.logger.video.save(self._step)
return dict( return dict(
episode_reward=np.nanmean(ep_rewards), episode_reward=np.nanmean(ep_rewards),
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
episode_length= np.nanmean(ep_lengths),
) )
def to_td(self, obs, action=None, reward=None, terminated=None): def to_td(self, obs, action=None, reward=None, terminated=None):
@@ -84,12 +88,14 @@ class OnlineTrainer(Trainer):
eval_next = False eval_next = False
if self._step > 0: if self._step > 0:
if info['terminated'] and not self.cfg.episodic:
raise ValueError('Termination detected but you are not in episodic mode. ' \
'Set `episodic=true` to enable support for terminations.')
train_metrics.update( train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(), episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'], episode_success=info['success'],
episode_length=len(self._tds), episode_length=len(self._tds),
episode_terminated=info['terminated'], episode_terminated=info['terminated'])
)
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(torch.cat(self._tds))