add walker2d

This commit is contained in:
Nicklas Hansen
2025-04-09 15:55:57 -07:00
parent 81eb17068e
commit c95b755655
4 changed files with 27 additions and 6 deletions

View File

@@ -184,10 +184,11 @@ class WorldModel(nn.Module):
`return_type` can be one of [`min`, `avg`, `all`]:
- `min`: return the minimum of two randomly subsampled Q-values.
- `avg`: return the average of two randomly subsampled Q-values.
- 'avg-all': return the average of all Q-values.
- `all`: return all Q-values.
`target` specifies whether to use the target Q-networks or not.
"""
assert return_type in {'min', 'avg', 'all'}
assert return_type in {'min', 'avg', 'avg-all', 'all'}
if self.cfg.multitask:
z = self.task_emb(z, task)
@@ -204,6 +205,10 @@ class WorldModel(nn.Module):
if return_type == 'all':
return out
if return_type == 'avg-all':
Q = math.two_hot_inv(out, self.cfg)
return Q.mean(0)
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
Q = math.two_hot_inv(out[qidx], self.cfg)
if return_type == "min":

View File

@@ -16,7 +16,7 @@ steps: 10_000_000
batch_size: 256
reward_coef: 0.1
value_coef: 0.1
termination_coef: 1
termination_coef: 20
consistency_coef: 20
rho: 0.5
lr: 3e-4

View File

@@ -4,6 +4,7 @@ from envs.wrappers.timeout import Timeout
MUJOCO_TASKS = {
'mujoco-walker': 'Walker2d-v4',
'mujoco-halfcheetah': 'HalfCheetah-v4',
'lunarlander-continuous': 'LunarLander-v2',
}
@@ -46,7 +47,8 @@ def make_env(cfg):
if cfg.task == 'lunarlander-continuous':
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
else:
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False)
env = MuJoCoWrapper(env, cfg)
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
cfg.discount_max = 0.99 # TODO: temporarily hardcore for these envs, makes comparison to other codebases easier
return env

View File

@@ -128,7 +128,6 @@ class TDMPC2(torch.nn.Module):
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G = G + discount * (1-termination) * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
@@ -255,6 +254,7 @@ class TDMPC2(torch.nn.Module):
"""
action, _ = self.model.pi(next_z, task)
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True)
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, terminated, task=None):
@@ -314,13 +314,27 @@ class TDMPC2(torch.nn.Module):
# Return training statistics
self.model.eval()
# termination classification metrics
# number of terminations in batch
termination_rate = terminated[-1].sum() / self.cfg.batch_size
# recall = TP / (TP + FN)
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
# precision = TP / (TP + FP)
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
# F1 score = 2 * (precision * recall) / (precision + recall)
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
info = TensorDict({
"consistency_loss": consistency_loss,
"reward_loss": reward_loss,
"value_loss": value_loss,
"termination_loss": termination_loss,
"termination_mean": termination_pred.mean(),
"termination_mean_gt": terminated[-1].mean(),
"termination_rate": termination_rate,
"termination_recall": termination_recall,
"termination_precision": termination_precision,
"termination_f1": termination_f1,
"total_loss": total_loss,
"grad_norm": grad_norm,
})