From 38f853efc4aac4873cb535f56df5e02513fdef4d Mon Sep 17 00:00:00 2001 From: Nicklas Hansen Date: Tue, 15 Apr 2025 10:16:02 -0700 Subject: [PATCH] clean up --- tdmpc2/envs/mujoco.py | 5 +++-- tdmpc2/tdmpc2.py | 8 +------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tdmpc2/envs/mujoco.py b/tdmpc2/envs/mujoco.py index a41d4ad..6c5946d 100644 --- a/tdmpc2/envs/mujoco.py +++ b/tdmpc2/envs/mujoco.py @@ -47,8 +47,9 @@ def make_env(cfg): if cfg.task == 'lunarlander-continuous': env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array') else: - env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False) + env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') env = MuJoCoWrapper(env, cfg) env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000) - cfg.discount_max = 0.99 # TODO: temporarily hardcore for these envs, makes comparison to other codebases easier + cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier + cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks return env diff --git a/tdmpc2/tdmpc2.py b/tdmpc2/tdmpc2.py index 61de435..80a72cc 100755 --- a/tdmpc2/tdmpc2.py +++ b/tdmpc2/tdmpc2.py @@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module): self.discount = torch.tensor( [self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0' ) if self.cfg.multitask else self._get_discount(cfg.episode_length) - print('Episode length:', cfg.episode_length) + print('Max episode length:', cfg.episode_length) print('Discount factor:', self.discount) self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)) if cfg.compile: @@ -254,7 +254,6 @@ class TDMPC2(torch.nn.Module): """ action, _ = self.model.pi(next_z, task) discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount - # return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True) def _update(self, obs, action, reward, terminated, task=None): @@ -291,12 +290,7 @@ class TDMPC2(torch.nn.Module): consistency_loss = consistency_loss / self.cfg.horizon reward_loss = reward_loss / self.cfg.horizon - # termination_loss = F.binary_cross_entropy(termination_pred, terminated) termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated) - # termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none') - # weighted mean over time, with last time step weighted as much as the rest combined - # termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2) - # termination_loss = termination_loss.mean() value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q) total_loss = ( self.cfg.consistency_coef * consistency_loss +