clean up
This commit is contained in:
@@ -47,8 +47,9 @@ def make_env(cfg):
|
|||||||
if cfg.task == 'lunarlander-continuous':
|
if cfg.task == 'lunarlander-continuous':
|
||||||
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
env = gym.make(MUJOCO_TASKS[cfg.task], continuous=True, render_mode='rgb_array')
|
||||||
else:
|
else:
|
||||||
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array') #, terminate_when_unhealthy=False)
|
env = gym.make(MUJOCO_TASKS[cfg.task], render_mode='rgb_array')
|
||||||
env = MuJoCoWrapper(env, cfg)
|
env = MuJoCoWrapper(env, cfg)
|
||||||
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
env = Timeout(env, max_episode_steps=500 if cfg.task.startswith('lunarlander') else 1000)
|
||||||
cfg.discount_max = 0.99 # TODO: temporarily hardcore for these envs, makes comparison to other codebases easier
|
cfg.discount_max = 0.99 # TODO: temporarily hardcode for these envs, makes comparison to other codebases easier
|
||||||
|
cfg.rho = 0.7 # TODO: temporarily increase rho for episodic tasks
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.discount = torch.tensor(
|
self.discount = torch.tensor(
|
||||||
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
[self._get_discount(ep_len) for ep_len in cfg.episode_lengths], device='cuda:0'
|
||||||
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
) if self.cfg.multitask else self._get_discount(cfg.episode_length)
|
||||||
print('Episode length:', cfg.episode_length)
|
print('Max episode length:', cfg.episode_length)
|
||||||
print('Discount factor:', self.discount)
|
print('Discount factor:', self.discount)
|
||||||
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
self._prev_mean = torch.nn.Buffer(torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device))
|
||||||
if cfg.compile:
|
if cfg.compile:
|
||||||
@@ -254,7 +254,6 @@ class TDMPC2(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
action, _ = self.model.pi(next_z, task)
|
action, _ = self.model.pi(next_z, task)
|
||||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||||
# return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min-all', target=True)
|
|
||||||
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
return reward + discount * (1-terminated) * self.model.Q(next_z, action, task, return_type='min', target=True)
|
||||||
|
|
||||||
def _update(self, obs, action, reward, terminated, task=None):
|
def _update(self, obs, action, reward, terminated, task=None):
|
||||||
@@ -291,12 +290,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
|
|
||||||
consistency_loss = consistency_loss / self.cfg.horizon
|
consistency_loss = consistency_loss / self.cfg.horizon
|
||||||
reward_loss = reward_loss / self.cfg.horizon
|
reward_loss = reward_loss / self.cfg.horizon
|
||||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated)
|
|
||||||
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
termination_loss = F.binary_cross_entropy_with_logits(termination_pred, terminated)
|
||||||
# termination_loss = F.binary_cross_entropy(termination_pred, terminated, reduction='none')
|
|
||||||
# weighted mean over time, with last time step weighted as much as the rest combined
|
|
||||||
# termination_loss[:-1] = termination_loss[:-1] / (self.cfg.horizon**2)
|
|
||||||
# termination_loss = termination_loss.mean()
|
|
||||||
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
value_loss = value_loss / (self.cfg.horizon * self.cfg.num_q)
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
|
|||||||
Reference in New Issue
Block a user