add walker2d
This commit is contained in:
@@ -184,10 +184,11 @@ class WorldModel(nn.Module):
|
||||
`return_type` can be one of [`min`, `avg`, `avg-all`, `all`]:
|
||||
- `min`: return the minimum of two randomly subsampled Q-values.
|
||||
- `avg`: return the average of two randomly subsampled Q-values.
|
||||
- `avg-all`: return the average of all Q-values.
|
||||
- `all`: return all Q-values.
|
||||
`target` specifies whether to use the target Q-networks or not.
|
||||
"""
|
||||
assert return_type in {'min', 'avg', 'all'}
|
||||
assert return_type in {'min', 'avg', 'avg-all', 'all'}
|
||||
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
@@ -204,6 +205,10 @@ class WorldModel(nn.Module):
|
||||
if return_type == 'all':
|
||||
return out
|
||||
|
||||
if return_type == 'avg-all':
|
||||
Q = math.two_hot_inv(out, self.cfg)
|
||||
return Q.mean(0)
|
||||
|
||||
qidx = torch.randperm(self.cfg.num_q, device=out.device)[:2]
|
||||
Q = math.two_hot_inv(out[qidx], self.cfg)
|
||||
if return_type == "min":
|
||||
|
||||
@@ -16,7 +16,7 @@ steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
termination_coef: 1
|
||||
termination_coef: 20
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
lr: 3e-4
|
||||
|
||||
@@ -4,6 +4,7 @@ from envs.wrappers.timeout import Timeout
|
||||
|
||||
|
||||
MUJOCO_TASKS = {
|
||||
'mujoco-walker': 'Walker2d-v4',
|
||||
'mujoco-halfcheetah': 'HalfCheetah-v4',
|
||||
'lunarlander-continuous': 'LunarLander-v2',
|
||||
}
|
||||
@@ -46,7 +47,8 @@ def make_env(cfg):
|
||||
if cfg.task == 'lunarlander-continuous':
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
||||
else:
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False)
|
||||
env = MuJoCoWrapper(env, cfg)
|
||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
||||
cfg.discount_max = 0.99 # TODO: temporarily hardcoded for these envs, makes comparison to other codebases easier
|
||||
return env
|
||||
|
||||
@@ -128,7 +128,6 @@ class TDMPC2(torch.nn.Module):
|
||||
for t in range(self.cfg.horizon):
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||
z = self.model.next(z, actions[t], task)
|
||||
|
||||
G = G + discount * (1-termination) * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
@@ -255,6 +254,7 @@ class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
action, _ = self.model.pi(next_z, task)
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='avg-all', target=True)
|
||||
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||
|
||||
def _update(self, obs, action, reward, terminated, task=None):
|
||||
@@ -314,13 +314,27 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
# Return training statistics
|
||||
self.model.eval()
|
||||
# termination classification metrics
|
||||
# number of terminations in terminated[-1], normalized by the batch size
|
||||
termination_rate = terminated[-1].sum() / self.cfg.batch_size
|
||||
# recall = TP / (TP + FN)
|
||||
termination_tp = ((termination_pred > 0.5) & (terminated[-1] == 1)).sum()
|
||||
termination_fn = ((termination_pred <= 0.5) & (terminated[-1] == 1)).sum()
|
||||
termination_fp = ((termination_pred > 0.5) & (terminated[-1] == 0)).sum()
|
||||
termination_recall = termination_tp / (termination_tp + termination_fn + 1e-9)
|
||||
# precision = TP / (TP + FP)
|
||||
termination_precision = termination_tp / (termination_tp + termination_fp + 1e-9)
|
||||
# F1 score = 2 * (precision * recall) / (precision + recall)
|
||||
termination_f1 = 2 * (termination_precision * termination_recall) / (termination_precision + termination_recall + 1e-9)
|
||||
info = TensorDict({
|
||||
"consistency_loss": consistency_loss,
|
||||
"reward_loss": reward_loss,
|
||||
"value_loss": value_loss,
|
||||
"termination_loss": termination_loss,
|
||||
"termination_mean": termination_pred.mean(),
|
||||
"termination_mean_gt": terminated[-1].mean(),
|
||||
"termination_rate": termination_rate,
|
||||
"termination_recall": termination_recall,
|
||||
"termination_precision": termination_precision,
|
||||
"termination_f1": termination_f1,
|
||||
"total_loss": total_loss,
|
||||
"grad_norm": grad_norm,
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user