full support for episodic rl
This commit is contained in:
@@ -16,7 +16,7 @@ CONSOLE_FORMAT = [
|
|||||||
("step", "I", "int"),
|
("step", "I", "int"),
|
||||||
("episode_reward", "R", "float"),
|
("episode_reward", "R", "float"),
|
||||||
("episode_success", "S", "float"),
|
("episode_success", "S", "float"),
|
||||||
("total_time", "T", "time"),
|
("elapsed_time", "T", "time"),
|
||||||
]
|
]
|
||||||
|
|
||||||
CAT_TO_COLOR = {
|
CAT_TO_COLOR = {
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from tensordict import TensorDict
|
||||||
|
|
||||||
|
|
||||||
def soft_ce(pred, target, cfg):
|
def soft_ce(pred, target, cfg):
|
||||||
@@ -84,11 +85,26 @@ def two_hot_inv(x, cfg):
|
|||||||
|
|
||||||
|
|
||||||
def gumbel_softmax_sample(p, temperature=1.0, dim=0):
    """Sample category indices from probabilities `p` using Gumbel noise.

    Gumbel(0, 1) noise is added to `log(p)`, the result is scaled by
    `temperature`, soft-maxed along `dim`, and the index of the largest
    entry along that same dimension is returned.

    Args:
        p: Tensor of probabilities. Entries equal to zero get a logit of
            `-inf` and are never selected as long as the slice contains
            at least one positive entry.
        temperature: Divisor applied to the perturbed logits before the
            softmax.
        dim: Dimension that indexes the categories. Both the softmax and
            the final argmax are taken along this dimension.

    Returns:
        Integer tensor of sampled indices with `dim` reduced away.
    """
    logits = p.log()
    # -log(E) with E ~ Exponential(1) is a Gumbel(0, 1) sample.
    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / temperature  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)
    # Reduce over the same dimension the softmax normalised over. The
    # previous hard-coded `-1` only matched `dim` for 1-D inputs (or when
    # `dim` happened to be the last axis).
    return y_soft.argmax(dim)
|
||||||
|
|
||||||
|
|
||||||
|
def termination_statistics(pred, target, eps=1e-9):
    """Compute episode termination statistics.

    Predictions are thresholded at 0.5 and compared against the labels to
    obtain an F1 score; the fraction of positive labels is reported as the
    termination rate.

    Args:
        pred: Predicted termination values; a trailing singleton dimension
            is squeezed away. Values above 0.5 count as a predicted
            termination.
        target: Ground-truth labels, squeezed the same way. Entries equal
            to 1 are positives and entries equal to 0 are negatives.
        eps: Small constant added to each denominator to avoid division
            by zero.

    Returns:
        TensorDict with the keys `termination_rate` and `termination_f1`.
    """
    scores = pred.squeeze(-1)
    labels = target.squeeze(-1)

    # Fraction of positive labels, normalised by the leading dimension.
    termination_rate = labels.sum() / len(labels)

    is_positive = labels == 1
    flagged = scores > 0.5
    # Kept as an explicit `<=` (rather than `~flagged`) so NaN scores fall
    # into neither bucket.
    not_flagged = scores <= 0.5

    true_pos = (flagged & is_positive).sum()
    false_neg = (not_flagged & is_positive).sum()
    false_pos = (flagged & (labels == 0)).sum()

    recall = true_pos / (true_pos + false_neg + eps)
    precision = true_pos / (true_pos + false_pos + eps)
    f1_score = 2 * (precision * recall) / (precision + recall + eps)

    return TensorDict({'termination_rate': termination_rate,
                       'termination_f1': f1_score})
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ class WorldModel(nn.Module):
|
|||||||
repr = 'TD-MPC2 World Model\n'
|
repr = 'TD-MPC2 World Model\n'
|
||||||
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
modules = ['Encoder', 'Dynamics', 'Reward', 'Termination', 'Policy prior', 'Q-functions']
|
||||||
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
for i, m in enumerate([self._encoder, self._dynamics, self._reward, self._termination, self._pi, self._Qs]):
|
||||||
|
if m == self._termination and not self.cfg.episodic:
|
||||||
|
continue
|
||||||
repr += f"{modules[i]}: {m}\n"
|
repr += f"{modules[i]}: {m}\n"
|
||||||
repr += "Learnable parameters: {:,}".format(self.total_params)
|
repr += "Learnable parameters: {:,}".format(self.total_params)
|
||||||
return repr
|
return repr
|
||||||
@@ -127,16 +129,17 @@ class WorldModel(nn.Module):
|
|||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._reward(z)
|
return self._reward(z)
|
||||||
|
|
||||||
def termination(self, z, task, unnormalized=False):
    """
    Predicts termination signal.

    Args:
        z: Latent state passed to the termination head.
        task: Task identifier; currently asserted to be `None`.
        unnormalized: If True, return the raw output of the termination
            head (logits); otherwise pass it through a sigmoid first.

    Returns:
        Output of `self._termination`, raw or sigmoid-squashed.
    """
    # NOTE(review): `task` must be None here, yet the multitask branch
    # below forwards it to `task_emb` -- confirm intended behaviour.
    assert task is None
    if self.cfg.multitask:
        z = self.task_emb(z, task)
    logits = self._termination(z)
    return logits if unnormalized else torch.sigmoid(logits)
|
||||||
|
|
||||||
|
|
||||||
def pi(self, z, task):
|
def pi(self, z, task):
|
||||||
"""
|
"""
|
||||||
@@ -186,12 +189,10 @@ class WorldModel(nn.Module):
|
|||||||
`return_type` can be one of [`min`, `avg`, `all`]:
|
`return_type` can be one of [`min`, `avg`, `all`]:
|
||||||
- `min`: return the minimum of two randomly subsampled Q-values.
|
- `min`: return the minimum of two randomly subsampled Q-values.
|
||||||
- `avg`: return the average of two randomly subsampled Q-values.
|
- `avg`: return the average of two randomly subsampled Q-values.
|
||||||
- 'min-all': return the minimum of all Q-values.
|
|
||||||
- 'avg-all': return the average of all Q-values.
|
|
||||||
- `all`: return all Q-values.
|
- `all`: return all Q-values.
|
||||||
`target` specifies whether to use the target Q-networks or not.
|
`target` specifies whether to use the target Q-networks or not.
|
||||||
"""
|
"""
|
||||||
assert return_type in {'min', 'avg', 'min-all', 'avg-all', 'all'}
|
assert return_type in {'min', 'avg', 'all'}
|
||||||
|
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
z = self.task_emb(z, task)
|
z = self.task_emb(z, task)
|
||||||
@@ -208,14 +209,6 @@ class WorldModel(nn.Module):
|
|||||||
if return_type == 'all':
|
if return_type == 'all':
|
||||||
return out
|
return out
|
||||||
|
|
||||||
if return_type == 'avg-all':
|
|
||||||
Q = math.two_hot_inv(out, self.cfg)
|
|
||||||
return Q.mean(0)
|
|
||||||
|
|
||||||
if return_type == 'min-all':
|
|
||||||
Q = math.two_hot_inv(out, self.cfg)
|
|
||||||
return Q.min(0).values
|
|
||||||
|
|
||||||
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
||||||
Q = math.two_hot_inv(out[qidx], self.cfg)
|
Q = math.two_hot_inv(out[qidx], self.cfg)
|
||||||
if return_type == "min":
|
if return_type == "min":
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ defaults:
|
|||||||
- override hydra/launcher: submitit_local
|
- override hydra/launcher: submitit_local
|
||||||
|
|
||||||
# environment
|
# environment
|
||||||
task: cartpole-balance-sparse
|
task: dog-run
|
||||||
obs: state
|
obs: state
|
||||||
episodic: true
|
episodic: false
|
||||||
|
|
||||||
# evaluation
|
# evaluation
|
||||||
checkpoint: ???
|
checkpoint: ???
|
||||||
|
|||||||
@@ -9,9 +9,8 @@ from dm_control import suite
|
|||||||
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
suite.ALL_TASKS = suite.ALL_TASKS + suite._get_tasks('custom')
|
||||||
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
suite.TASKS_BY_DOMAIN = suite._get_tasks_by_domain(suite.ALL_TASKS)
|
||||||
from dm_control.suite.wrappers import action_scale
|
from dm_control.suite.wrappers import action_scale
|
||||||
from envs.wrappers.timeout import Timeout
|
|
||||||
|
|
||||||
from envs.wrappers.episodic import EpisodicWrapper
|
from envs.wrappers.timeout import Timeout
|
||||||
|
|
||||||
|
|
||||||
def get_obs_shape(env):
|
def get_obs_shape(env):
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from envs.wrappers.timeout import Timeout
|
|||||||
MUJOCO_TASKS = {
|
MUJOCO_TASKS = {
|
||||||
'mujoco-walker': 'Walker2d-v4',
|
'mujoco-walker': 'Walker2d-v4',
|
||||||
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
||||||
|
'bipedal-walker': 'BipedalWalker-v3',
|
||||||
'lunarlander-continuous': 'LunarLander-v2',
|
'lunarlander-continuous': 'LunarLander-v2',
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,7 +50,10 @@ def make_env(cfg):
|
|||||||
else:
|
else:
|
||||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||||
env = MuJoCoWrapper(env, cfg)
|
env = MuJoCoWrapper(env, cfg)
|
||||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
env = Timeout(env, max_episode_steps={
|
||||||
|
'lunarlander-continuous': 500,
|
||||||
|
'bipedal-walker': 1600,
|
||||||
|
}.get(cfg.task, 1000)) # Default max episode steps for other tasks
|
||||||
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
||||||
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
|
cfg.rho = 0.7 # TODO: increase rho for episodic tasks since termination always happens at the end of a sequence
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
from collections import deque
|
|
||||||
|
|
||||||
import gymnasium as gym
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodicWrapper(gym.Wrapper):
|
|
||||||
"""
|
|
||||||
Wrapper for testing episodic tasks. Only compatible with cartpole-balance-sparse at the moment.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, cfg, env):
|
|
||||||
super().__init__(env)
|
|
||||||
assert cfg.task == 'cartpole-balance-sparse'
|
|
||||||
self.cfg = cfg
|
|
||||||
self.env = env
|
|
||||||
|
|
||||||
def step(self, action):
|
|
||||||
obs, reward, done, info = self.env.step(action)
|
|
||||||
if self.cfg.episodic and reward == 0:
|
|
||||||
done = True
|
|
||||||
info['terminated'] = True
|
|
||||||
return obs, reward, done, info
|
|
||||||
@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
print('Max episode length:', cfg.episode_length)
|
print('Episode length:', cfg.episode_length)
|
||||||
print('Discount factor:', self.discount)
|
print('Discount factor:', self.discount)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
@@ -197,7 +197,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
std = std * self.model._action_masks[task]
|
std = std * self.model._action_masks[task]
|
||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(1))
|
||||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||||
a, std = actions[0], std[0]
|
a, std = actions[0], std[0]
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
@@ -279,7 +279,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
termination_pred = self.model.termination(zs[1:], task, sigmoid=False)
|
termination_pred = self.model.termination(zs[1:], task, unnormalized=True)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, value_loss = 0, 0
|
reward_loss, value_loss = 0, 0
|
||||||
@@ -290,7 +290,10 @@ class TDMPC2(torch.nn.Module):
|
|||||||
|
|
||||||
consistency_loss = consistency_loss / self.cfg.horizon
|
consistency_loss = consistency_loss / self.cfg.horizon
|
||||||
reward_loss = reward_loss / self.cfg.horizon
|
reward_loss = reward_loss / self.cfg.horizon
|
||||||
|
if self.cfg.episodic:
|
||||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||||
|
else:
|
||||||
|
termination_loss = 0.
|
||||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
@@ -313,30 +316,16 @@ class TDMPC2(torch.nn.Module):
|
|||||||
|
|
||||||
# Return training statistics
|
# Return training statistics
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
# termination classification metrics
|
|
||||||
# number of terminations in batch
|
|
||||||
termination_rate = terminated[-1].sum() / self.cfg.batch_size
|
|
||||||
# recall = TP / (TP + FN)
|
|
||||||
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
|
|
||||||
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
|
|
||||||
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
|
|
||||||
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
|
|
||||||
# precision = TP / (TP + FP)
|
|
||||||
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
|
|
||||||
# F1 score = 2 * (precision * recall) / (precision + recall)
|
|
||||||
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
|
|
||||||
info = TensorDict({
|
info = TensorDict({
|
||||||
"consistency_loss": consistency_loss,
|
"consistency_loss": consistency_loss,
|
||||||
"reward_loss": reward_loss,
|
"reward_loss": reward_loss,
|
||||||
"value_loss": value_loss,
|
"value_loss": value_loss,
|
||||||
"termination_loss": termination_loss,
|
"termination_loss": termination_loss,
|
||||||
"termination_rate": termination_rate,
|
|
||||||
"termination_recall": termination_recall,
|
|
||||||
"termination_precision": termination_precision,
|
|
||||||
"termination_f1": termination_f1,
|
|
||||||
"total_loss": total_loss,
|
"total_loss": total_loss,
|
||||||
"grad_norm": grad_norm,
|
"grad_norm": grad_norm,
|
||||||
})
|
})
|
||||||
|
if self.cfg.episodic:
|
||||||
|
info.update(math.termination_statistics(torch.sigmoid(termination_pred[-1]), terminated[-1]))
|
||||||
info.update(pi_info)
|
info.update(pi_info)
|
||||||
return info.detach().mean()
|
return info.detach().mean()
|
||||||
|
|
||||||
|
|||||||
@@ -48,8 +48,6 @@ def train(cfg: dict):
|
|||||||
cfg = parse_cfg(cfg)
|
cfg = parse_cfg(cfg)
|
||||||
set_seed(cfg.seed)
|
set_seed(cfg.seed)
|
||||||
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
print(colored('Work dir:', 'yellow', attrs=['bold']), cfg.work_dir)
|
||||||
assert cfg.episodic, \
|
|
||||||
f'This branch is experimental and only supports episodic RL tasks at this time.'
|
|
||||||
|
|
||||||
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
trainer_cls = OfflineTrainer if cfg.multitask else OnlineTrainer
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ class OfflineTrainer(Trainer):
|
|||||||
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
if i % self.cfg.eval_freq == 0 or i % 10_000 == 0:
|
||||||
metrics = {
|
metrics = {
|
||||||
'iteration': i,
|
'iteration': i,
|
||||||
'total_time': time() - self._start_time,
|
'elapsed_time': time() - self._start_time,
|
||||||
}
|
}
|
||||||
metrics.update(train_metrics)
|
metrics.update(train_metrics)
|
||||||
if i % self.cfg.eval_freq == 0:
|
if i % self.cfg.eval_freq == 0:
|
||||||
|
|||||||
@@ -17,15 +17,17 @@ class OnlineTrainer(Trainer):
|
|||||||
|
|
||||||
def common_metrics(self):
    """Return a dictionary of current metrics.

    Returns:
        dict with the keys `step` (`self._step`), `episode`
        (`self._ep_idx`), `elapsed_time` (seconds since
        `self._start_time`) and `steps_per_second`.
    """
    elapsed_time = time() - self._start_time
    # Guard the throughput computation: if the clock has not advanced
    # since `_start_time`, report 0 instead of raising ZeroDivisionError.
    steps_per_second = self._step / elapsed_time if elapsed_time > 0 else 0.0
    return dict(
        step=self._step,
        episode=self._ep_idx,
        elapsed_time=elapsed_time,
        steps_per_second=steps_per_second,
    )
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
"""Evaluate a TD-MPC2 agent."""
|
"""Evaluate a TD-MPC2 agent."""
|
||||||
ep_rewards, ep_successes = [], []
|
ep_rewards, ep_successes, ep_lengths = [], [], []
|
||||||
for i in range(self.cfg.eval_episodes):
|
for i in range(self.cfg.eval_episodes):
|
||||||
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
obs, done, ep_reward, t = self.env.reset(), False, 0, 0
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
@@ -40,11 +42,13 @@ class OnlineTrainer(Trainer):
|
|||||||
self.logger.video.record(self.env)
|
self.logger.video.record(self.env)
|
||||||
ep_rewards.append(ep_reward)
|
ep_rewards.append(ep_reward)
|
||||||
ep_successes.append(info['success'])
|
ep_successes.append(info['success'])
|
||||||
|
ep_lengths.append(t)
|
||||||
if self.cfg.save_video:
|
if self.cfg.save_video:
|
||||||
self.logger.video.save(self._step)
|
self.logger.video.save(self._step)
|
||||||
return dict(
|
return dict(
|
||||||
episode_reward=np.nanmean(ep_rewards),
|
episode_reward=np.nanmean(ep_rewards),
|
||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
|
episode_length= np.nanmean(ep_lengths),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None, terminated=None):
|
def to_td(self, obs, action=None, reward=None, terminated=None):
|
||||||
@@ -84,12 +88,14 @@ class OnlineTrainer(Trainer):
|
|||||||
eval_next = False
|
eval_next = False
|
||||||
|
|
||||||
if self._step > 0:
|
if self._step > 0:
|
||||||
|
if info['terminated'] and not self.cfg.episodic:
|
||||||
|
raise ValueError('Termination detected but you are not in episodic mode. ' \
|
||||||
|
'Set `episodic=true` to enable support for terminations.')
|
||||||
train_metrics.update(
|
train_metrics.update(
|
||||||
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
episode_reward=torch.tensor([td['reward'] for td in self._tds[1:]]).sum(),
|
||||||
episode_success=info['success'],
|
episode_success=info['success'],
|
||||||
episode_length=len(self._tds),
|
episode_length=len(self._tds),
|
||||||
episode_terminated=info['terminated'],
|
episode_terminated=info['terminated'])
|
||||||
)
|
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
|
|||||||
Reference in New Issue
Block a user