full support for episodic rl

This commit is contained in:
Nicklas Hansen
2025-04-15 15:55:05 -07:00
parent 38f853efc4
commit eece80123d
11 changed files with 55 additions and 74 deletions

View File

@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
("step", "I", "int"),
("episode_reward", "R", "float"),
("episode_success", "S", "float"),
("total_time", "T", "time"),
("elapsed_time", "T", "time"),
]
CAT_TO_COLOR = {

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from tensordict import TensorDict
def soft_ce(pred, target, cfg):
@@ -84,11 +85,26 @@ def two_hot_inv(x, cfg):
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
"""Sample from the Gumbel-Softmax distribution."""
logits = p.log()
# Generate Gumbel noise
gumbels = (
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
) # ~Gumbel(0,1)
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
y_soft = gumbels.softmax(dim)
return y_soft.argmax(-1)
def termination_statistics(pred, target, eps=1e-9):
"""Compute episode termination statistics."""
pred = pred.squeeze(-1)
target = target.squeeze(-1)
rate = target.sum() / len(target)
tp = ((pred > 0.5) & (target == 1)).sum()
fn = ((pred <= 0.5) & (target == 1)).sum()
fp = ((pred > 0.5) & (target == 0)).sum()
recall = tp / (tp + fn + eps)
precision = tp / (tp + fp + eps)
f1 = 2 * (precision * recall) / (precision + recall + eps)
return TensorDict({'termination_rate': rate,
'termination_f1': f1})

View File

@@ -56,6 +56,8 @@ class WorldModel(nn.Module):
repr = 'TD-MPC2 World Model\n'
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
if m == self._termination and not self.cfg.episodic:
continue
repr += f"{modules[i]}: {m}\n"
repr += "Learnable parameters: {:,}".format(self.total_params)
return repr
@@ -127,16 +129,17 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def termination(self, z, task, sigmoid=True):
def termination(self, z, task, unnormalized=False):
"""
Predicts termination signal.
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
if sigmoid:
return torch.sigmoid(self._termination(z))
return self._termination(z)
if unnormalized:
return self._termination(z)
return torch.sigmoid(self._termination(z))
def pi(self, z, task):
"""
@@ -186,12 +189,10 @@ class WorldModel(nn.Module):
`return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values.
- 'min-all': return the minimum of all Q-values.
- 'avg-all': return the average of all Q-values.
- `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not.
"""
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
assert return_type in {'min', 'avg', 'all'}
if self.cfg.multitask:
z = self.task_emb(z, task)
@@ -208,14 +209,6 @@ class WorldModel(nn.Module):
if return_type == 'all':
return out
if return_type == 'avg-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.mean(0)
if return_type == 'min-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.min(0).values
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q = math.two_hot_inv(out[qidx], self.cfg)
if return_type == "min":

View File

@@ -2,9 +2,9 @@ defaults:
- override hydra/launcher: submitit_local
# environment
task: cartpole-balance-sparse
task: dog-run
obs: state
episodic: true
episodic: false
# evaluation
checkpoint: ???

View File

@@ -9,9 +9,8 @@ from dm_control import suite
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
from dm_control.suite.wrappers import action_scale
from envs.wrappers.timeout import Timeout
from envs.wrappers.episodic import EpisodicWrapper
from envs.wrappers.timeout import Timeout
def get_obs_shape(env):

View File

@@ -6,6 +6,7 @@ from envs.wrappers.timeout import Timeout
MUJOCO_TASKS = {
'mujoco-walker': 'Walker2d-v4',
'mujoco-halfcheetah': 'HalfCheetah-v4',
'bipedal-walker': 'BipedalWalker-v3',
'lunarlander-continuous': 'LunarLander-v2',
}
@@ -49,7 +50,10 @@ def make_env(cfg):
else:
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
env = MuJoCoWrapper(env, cfg)
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
env = Timeout(env, max_episode_steps={
'lunarlander-continuous': 500,
'bipedal-walker': 1600,
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
return env

View File

@@ -1,24 +0,0 @@
from collections import deque
import gymnasium as gym
import numpy as np
import torch
class EpisodicWrapper(gym.Wrapper):
"""
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
"""
def __init__(self, cfg, env):
super().__init__(env)
assert cfg.task == 'cartpole-balance-sparse'
self.cfg = cfg
self.env = env
def step(self, action):
obs, reward, done, info = self.env.step(action)
if self.cfg.episodic and reward == 0:
done = True
info['terminated'] = True
return obs, reward, done, info

View File

@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
print('Max episode length:', cfg.episode_length)
print('Episode length:', cfg.episode_length)
print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
@@ -197,7 +197,7 @@ class TDMPC2(torch.nn.Module):
std = std * self.model._action_masks[task]
# Select action
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
a, std = actions[0], std[0]
if not eval_mode:
@@ -279,7 +279,7 @@ class TDMPC2(torch.nn.Module):
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
# Compute losses
reward_loss, value_loss = 0, 0
@@ -290,7 +290,10 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
if self.cfg.episodic:
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
else:
termination_loss = 0.
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = (
self.cfg.consistency_coef * consistency_loss +
@@ -313,30 +316,16 @@ class TDMPC2(torch.nn.Module):
# Return training statistics
self.model.eval()
# termination classification metrics
# number of terminations in batch
termination_rate = terminated[-1].sum() / self.cfg.batch_size
# recall = TP / (TP + FN)
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
# precision = TP / (TP + FP)
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
# F1 score = 2 * (precision * recall) / (precision + recall)
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
info = TensorDict({
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"value_loss": value_loss,
"termination_loss": termination_loss,
"termination_rate": termination_rate,
"termination_recall": termination_recall,
"termination_precision": termination_precision,
"termination_f1": termination_f1,
"total_loss": total_loss,
"grad_norm": grad_norm,
})
if self.cfg.episodic:
info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
info.update(pi_info)
return info.detach().mean()

View File

@@ -48,8 +48,6 @@ def train(cfg: dict):
cfg = parse_cfg(cfg)
set_seed(cfg.seed)
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
assert cfg.episodic, \
f'This branch is experimental and only supports episodic RL tasks at this time.'
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
trainer = trainer_cls(

View File

@@ -81,7 +81,7 @@ class OfflineTrainer(Trainer):
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
metrics = {
'iteration': i,
'total_time': time() - self._start_time,
'elapsed_time': time() - self._start_time,
}
metrics.update(train_metrics)
if i % self.cfg.eval_freq == 0:

View File

@@ -17,15 +17,17 @@ class OnlineTrainer(Trainer):
def common_metrics(self):
"""Return a dictionary of current metrics."""
elapsed_time = time() - self._start_time
return dict(
step=self._step,
episode=self._ep_idx,
total_time=time() - self._start_time,
elapsed_time=elapsed_time,
steps_per_second=self._step / elapsed_time
)
def eval(self):
"""Evaluate a TD-MPC2 agent."""
ep_rewards, ep_successes = [], []
ep_rewards, ep_successes, ep_lengths = [], [], []
for i in range(self.cfg.eval_episodes):
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
if self.cfg.save_video:
@@ -40,11 +42,13 @@ class OnlineTrainer(Trainer):
self.logger.video.record(self.env)
ep_rewards.append(ep_reward)
ep_successes.append(info['success'])
ep_lengths.append(t)
if self.cfg.save_video:
self.logger.video.save(self._step)
return dict(
episode_reward=np.nanmean(ep_rewards),
episode_success=np.nanmean(ep_successes),
episode_length= np.nanmean(ep_lengths),
)
def to_td(self, obs, action=None, reward=None, terminated=None):
@@ -84,12 +88,14 @@ class OnlineTrainer(Trainer):
eval_next = False
if self._step > 0:
if info['terminated'] and not self.cfg.episodic:
raise ValueError('Termination detected but you are not in episodic mode. ' \
'Set `episodic=true` to enable support for terminations.')
train_metrics.update(
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
episode_success=info['success'],
episode_length=len(self._tds),
episode_terminated=info['terminated'],
)
episode_terminated=info['terminated'])
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._ep_idx = self.buffer.add(torch.cat(self._tds))