clean up
This commit is contained in:
@@ -47,8 +47,9 @@ def make_env(cfg):
|
||||
if cfg.task == 'lunarlander-continuous':
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
||||
else:
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False)
|
||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||
env = MuJoCoWrapper(env, cfg)
|
||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
||||
cfg.discount_max = 0.99 # TODO: temporarily hardcore for these envs, makes comparison to other codebases easier
|
||||
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
||||
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
|
||||
return env
|
||||
|
||||
@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
|
||||
self.discount = torch.tensor(
|
||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||
print('Episode length:', cfg.episode_length)
|
||||
print('Max episode length:', cfg.episode_length)
|
||||
print('Discount factor:', self.discount)
|
||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||
if cfg.compile:
|
||||
@@ -254,7 +254,6 @@ class TDMPC2(torch.nn.Module):
|
||||
"""
|
||||
action, _ = self.model.pi(next_z, task)
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
|
||||
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||
|
||||
def _update(self, obs, action, reward, terminated, task=None):
|
||||
@@ -291,12 +290,7 @@ class TDMPC2(torch.nn.Module):
|
||||
|
||||
consistency_loss = consistency_loss / self.cfg.horizon
|
||||
reward_loss = reward_loss / self.cfg.horizon
|
||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated)
|
||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
|
||||
# weighted mean over time, with last time step weighted as much as the rest combined
|
||||
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
|
||||
# termination_loss = termination_loss.mean()
|
||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
|
||||
Reference in New Issue
Block a user