This commit is contained in:
Nicklas Hansen
2025-04-15 10:16:02 -07:00
parent 62be41ab58
commit 38f853efc4
2 changed files with 4 additions and 9 deletions

View File

@@ -47,8 +47,9 @@ def make_env(cfg):
if cfg.task == 'lunarlander-continuous':
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
else:
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False)
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
env = MuJoCoWrapper(env, cfg)
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
cfg.discount_max = 0.99 # TODO: temporarily hardcore for these envs, makes comparison to other codebases easier
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
return env

View File

@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
self.discount = torch.tensor(
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
print('Episode length:', cfg.episode_length)
print('Max episode length:', cfg.episode_length)
print('Discount factor:', self.discount)
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
if cfg.compile:
@@ -254,7 +254,6 @@ class TDMPC2(torch.nn.Module):
"""
action, _ = self.model.pi(next_z, task)
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
def _update(self, obs, action, reward, terminated, task=None):
@@ -291,12 +290,7 @@ class TDMPC2(torch.nn.Module):
consistency_loss = consistency_loss / self.cfg.horizon
reward_loss = reward_loss / self.cfg.horizon
# termination_loss = F.binary_cross_entropy(termination_pred, terminated)
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
# weighted mean over time, with last time step weighted as much as the rest combined
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
# termination_loss = termination_loss.mean()
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
total_loss = (
self.cfg.consistency_coef * consistency_loss +