full support for episodic rl
This commit is contained in:
@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
|
||||
("step", "I", "int"),
|
||||
("episode_reward", "R", "float"),
|
||||
("episode_success", "S", "float"),
|
||||
("total_time", "T", "time"),
|
||||
("elapsed_time", "T", "time"),
|
||||
]
|
||||
|
||||
CAT_TO_COLOR = {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tensordict import TensorDict
|
||||
|
||||
|
||||
def soft_ce(pred, target, cfg):
|
||||
@@ -84,11 +85,26 @@ def two_hot_inv(x, cfg):
|
||||
|
||||
|
||||
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
|
||||
"""Sample from the Gumbel-Softmax distribution."""
|
||||
logits = p.log()
|
||||
# Generate Gumbel noise
|
||||
gumbels = (
|
||||
-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
|
||||
) # ~Gumbel(0,1)
|
||||
gumbels = (logits + gumbels) / temperature # ~Gumbel(logits,tau)
|
||||
y_soft = gumbels.softmax(dim)
|
||||
return y_soft.argmax(-1)
|
||||
|
||||
|
||||
def termination_statistics(pred, target, eps=1e-9):
|
||||
"""Compute episode termination statistics."""
|
||||
pred = pred.squeeze(-1)
|
||||
target = target.squeeze(-1)
|
||||
rate = target.sum() / len(target)
|
||||
tp = ((pred > 0.5) & (target == 1)).sum()
|
||||
fn = ((pred <= 0.5) & (target == 1)).sum()
|
||||
fp = ((pred > 0.5) & (target == 0)).sum()
|
||||
recall = tp / (tp + fn + eps)
|
||||
precision = tp / (tp + fp + eps)
|
||||
f1 = 2 * (precision * recall) / (precision + recall + eps)
|
||||
return TensorDict({'termination_rate': rate,
|
||||
'termination_f1': f1})
|
||||
|
||||
@@ -56,6 +56,8 @@ class WorldModel(nn.Module):
|
||||
repr = 'TD-MPC2 World Model\n'
|
||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
||||
if m == self._termination and not self.cfg.episodic:
|
||||
continue
|
||||
repr += f"{modules[i]}: {m}\n"
|
||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||
return repr
|
||||
@@ -127,16 +129,17 @@ class WorldModel(nn.Module):
|
||||
z = torch.cat([z, a], dim=-1)
|
||||
return self._reward(z)
|
||||
|
||||
def termination(self, z, task, sigmoid=True):
|
||||
def termination(self, z, task, unnormalized=False):
|
||||
"""
|
||||
Predicts termination signal.
|
||||
"""
|
||||
assert task is None
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
if sigmoid:
|
||||
return torch.sigmoid(self._termination(z))
|
||||
return self._termination(z)
|
||||
if unnormalized:
|
||||
return self._termination(z)
|
||||
return torch.sigmoid(self._termination(z))
|
||||
|
||||
|
||||
def pi(self, z, task):
|
||||
"""
|
||||
@@ -186,12 +189,10 @@ class WorldModel(nn.Module):
|
||||
`return_type` can be one of [`min`, `avg`, `all`]:
|
||||
- `min`: return the minimum of two randomly subsampled Q-values.
|
||||
- `avg`: return the average of two randomly subsampled Q-values.
|
||||
- 'min-all': return the minimum of all Q-values.
|
||||
- 'avg-all': return the average of all Q-values.
|
||||
- `all`: return all Q-values.
|
||||
`target` specifies whether to use the target Q-networks or not.
|
||||
"""
|
||||
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
|
||||
assert return_type in {'min', 'avg', 'all'}
|
||||
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
@@ -208,14 +209,6 @@ class WorldModel(nn.Module):
|
||||
if return_type == 'all':
|
||||
return out
|
||||
|
||||
if return_type == 'avg-all':
|
||||
Q = math.two_hot_inv(out, self.cfg)
|
||||
return Q.mean(0)
|
||||
|
||||
if return_type == 'min-all':
|
||||
Q = math.two_hot_inv(out, self.cfg)
|
||||
return Q.min(0).values
|
||||
|
||||
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
||||
Q = math.two_hot_inv(out[qidx], self.cfg)
|
||||
if return_type == "min":
|
||||
|
||||
@@ -2,9 +2,9 @@ defaults:
|
||||
- override hydra/launcher: submitit_local
|
||||
|
||||
# environment
|
||||
task: cartpole-balance-sparse
|
||||
task: dog-run
|
||||
obs: state
|
||||
episodic: true
|
||||
episodic: false
|
||||
|
||||
# evaluation
|
||||
checkpoint: ???
|
||||
|
||||
@@ -9,9 +9,8 @@ from dm_control import suite
|
||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||
from dm_control.suite.wrappers import action_scale
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
from envs.wrappers.episodic import EpisodicWrapper
|
||||
from envs.wrappers.timeout import Timeout
|
||||
|
||||
|
||||
def get_obs_shape(env):
|
||||
|
||||
@@ -6,6 +6,7 @@ from envs.wrappers.timeout import Timeout
|
||||
MUJOCO_TASKS = {
|
||||
'mujoco-walker': 'Walker2d-v4',
|
||||
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
||||
'bipedal-walker': 'BipedalWalker-v3',
|
||||
'lunarlander-continuous': 'LunarLander-v2',
|
||||
}
|
||||
|
||||
@@ -49,7 +50,10 @@ def make_env(cfg):
|
||||
else:
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||
env = MuJoCoWrapper(env, cfg)
|
||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
||||
env = Timeout(env, max_episode_steps={
|
||||
'lunarlander-continuous': 500,
|
||||
'bipedal-walker': 1600,
|
||||
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
|
||||
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
||||
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
|
||||
cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
|
||||
return env
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class EpisodicWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env):
|
||||
super().__init__(env)
|
||||
assert cfg.task == 'cartpole-balance-sparse'
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
if self.cfg.episodic and reward == 0:
|
||||
done = True
|
||||
info['terminated'] = True
|
||||
return obs, reward, done, info
|
||||
@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
print('Max episode length:', cfg.episode_length)
|
||||
print('Episode length:', cfg.episode_length)
|
||||
print('Discount factor:', self.discount)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
@@ -197,7 +197,7 @@ class TDMPC2(torch.nn.Module):
|
||||
std = std * self.model._action_masks[task]
|
||||
|
||||
# Select action
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
|
||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||
a, std = actions[0], std[0]
|
||||
if not eval_mode:
|
||||
@@ -279,7 +279,7 @@ class TDMPC2(torch.nn.Module):
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
|
||||
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, value_loss = 0, 0
|
||||
@@ -290,7 +290,10 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
consistency_loss = consistency_loss / self.cfg.horizon
|
||||
reward_loss = reward_loss / self.cfg.horizon
|
||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||
if self.cfg.episodic:
|
||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||
else:
|
||||
termination_loss = 0.
|
||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
@@ -313,30 +316,16 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
# Return training statistics
|
||||
self.model.eval()
|
||||
# termination classification metrics
|
||||
# number of terminations in batch
|
||||
termination_rate = terminated[-1].sum() / self.cfg.batch_size
|
||||
# recall = TP / (TP + FN)
|
||||
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
|
||||
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
|
||||
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
|
||||
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
|
||||
# precision = TP / (TP + FP)
|
||||
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
|
||||
# F1 score = 2 * (precision * recall) / (precision + recall)
|
||||
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
|
||||
info = TensorDict({
|
||||
"consistency_loss": consistency_loss,
|
||||
"reward_loss": reward_loss,
|
||||
"value_loss": value_loss,
|
||||
"termination_loss": termination_loss,
|
||||
"termination_rate": termination_rate,
|
||||
"termination_recall": termination_recall,
|
||||
"termination_precision": termination_precision,
|
||||
"termination_f1": termination_f1,
|
||||
"total_loss": total_loss,
|
||||
"grad_norm": grad_norm,
|
||||
})
|
||||
if self.cfg.episodic:
|
||||
info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
|
||||
info.update(pi_info)
|
||||
return info.detach().mean()
|
||||
|
||||
|
||||
@@ -48,8 +48,6 @@ def train(cfg: dict):
|
||||
cfg = parse_cfg(cfg)
|
||||
set_seed(cfg.seed)
|
||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||
assert cfg.episodic, \
|
||||
f'This branch is experimental and only supports episodic RL tasks at this time.'
|
||||
|
||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||
trainer = trainer_cls(
|
||||
|
||||
@@ -81,7 +81,7 @@ class OfflineTrainer(Trainer):
|
||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
||||
metrics = {
|
||||
'iteration': i,
|
||||
'total_time': time() - self._start_time,
|
||||
'elapsed_time': time() - self._start_time,
|
||||
}
|
||||
metrics.update(train_metrics)
|
||||
if i % self.cfg.eval_freq == 0:
|
||||
|
||||
@@ -17,15 +17,17 @@ class OnlineTrainer(Trainer):
|
||||
|
||||
def common_metrics(self):
|
||||
"""Return a dictionary of current metrics."""
|
||||
elapsed_time = time() - self._start_time
|
||||
return dict(
|
||||
step=self._step,
|
||||
episode=self._ep_idx,
|
||||
total_time=time() - self._start_time,
|
||||
elapsed_time=elapsed_time,
|
||||
steps_per_second=self._step / elapsed_time
|
||||
)
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate a TD-MPC2 agent."""
|
||||
ep_rewards, ep_successes = [], []
|
||||
ep_rewards, ep_successes, ep_lengths = [], [], []
|
||||
for i in range(self.cfg.eval_episodes):
|
||||
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
||||
if self.cfg.save_video:
|
||||
@@ -40,11 +42,13 @@ class OnlineTrainer(Trainer):
|
||||
self.logger.video.record(self.env)
|
||||
ep_rewards.append(ep_reward)
|
||||
ep_successes.append(info['success'])
|
||||
ep_lengths.append(t)
|
||||
if self.cfg.save_video:
|
||||
self.logger.video.save(self._step)
|
||||
return dict(
|
||||
episode_reward=np.nanmean(ep_rewards),
|
||||
episode_success=np.nanmean(ep_successes),
|
||||
episode_length= np.nanmean(ep_lengths),
|
||||
)
|
||||
|
||||
def to_td(self, obs, action=None, reward=None, terminated=None):
|
||||
@@ -84,12 +88,14 @@ class OnlineTrainer(Trainer):
|
||||
eval_next = False
|
||||
|
||||
if self._step > 0:
|
||||
if info['terminated'] and not self.cfg.episodic:
|
||||
raise ValueError('Termination detected but you are not in episodic mode. ' \
|
||||
'Set `episodic=true` to enable support for terminations.')
|
||||
train_metrics.update(
|
||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||
episode_success=info['success'],
|
||||
episode_length=len(self._tds),
|
||||
episode_terminated=info['terminated'],
|
||||
)
|
||||
episode_terminated=info['terminated'])
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
|
||||
Reference in New Issue
Block a user