This commit is contained in:
Nicklas Hansen
2024-01-07 18:16:33 -08:00
parent 33876d124f
commit 26c72119cd
7 changed files with 44 additions and 13 deletions

View File

@@ -78,8 +78,9 @@ class Buffer():
obs = td['obs'] obs = td['obs']
action = td['action'][1:] action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1) reward = td['reward'][1:].unsqueeze(-1)
terminated = td['terminated'][1:].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task) return self._to_device(obs, action, reward, terminated, task)
def add(self, td): def add(self, td):
"""Add an episode to the buffer.""" """Add an episode to the buffer."""

View File

@@ -24,6 +24,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg) self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg)) self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1)) self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim) self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)]) self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init) self.apply(init.weight_init)
@@ -119,6 +120,15 @@ class WorldModel(nn.Module):
z = torch.cat([z, a], dim=-1) z = torch.cat([z, a], dim=-1)
return self._reward(z) return self._reward(z)
def terminated(self, z, task):
"""
Predicts termination signal.
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
return torch.sigmoid(self._terminated(z))
def pi(self, z, task): def pi(self, z, task):
""" """
Samples an action from the policy prior. Samples an action from the policy prior.

View File

@@ -15,6 +15,7 @@ steps: 10_000_000
batch_size: 256 batch_size: 256
reward_coef: 0.1 reward_coef: 0.1
value_coef: 0.1 value_coef: 0.1
terminated_coef: 0.1
consistency_coef: 20 consistency_coef: 20
rho: 0.5 rho: 0.5
lr: 3e-4 lr: 3e-4

View File

@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
def step(self, action): def step(self, action):
reward = 0 reward = 0
for _ in range(2): for _ in range(2):
obs, r, _, info = self.env.step(action) obs, r, done, info = self.env.step(action)
reward += r reward += r
return obs, reward, False, info info['terminated'] = done
if done:
break
return obs, reward, done, info
@property @property
def unwrapped(self): def unwrapped(self):

View File

@@ -37,4 +37,5 @@ class TensorWrapper(gym.Wrapper):
obs, reward, done, info = self.env.step(action.numpy()) obs, reward, done, info = self.env.step(action.numpy())
info = defaultdict(float, info) info = defaultdict(float, info)
info['success'] = float(info['success']) info['success'] = float(info['success'])
info['terminated'] = torch.tensor(float(info['terminated']))
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -22,6 +22,7 @@ class TDMPC2:
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale}, {'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()}, {'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()}, {'params': self.model._reward.parameters()},
{'params': self.model._terminated.parameters()},
{'params': self.model._Qs.parameters()}, {'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []} {'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
], lr=self.cfg.lr) ], lr=self.cfg.lr)
@@ -95,12 +96,14 @@ class TDMPC2:
def _estimate_value(self, z, actions, task): def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions.""" """Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1 G, discount = 0, 1
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg) reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task) z = self.model.next(z, actions[t], task)
G += discount * reward G += discount * (1-terminated) * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg') terminated = torch.clip_(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
return G + discount * (1-terminated) * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad() @torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None): def plan(self, z, t0=False, eval_mode=False, task=None):
@@ -199,13 +202,14 @@ class TDMPC2:
return pi_loss.item() return pi_loss.item()
@torch.no_grad() @torch.no_grad()
def _td_target(self, next_z, reward, task): def _td_target(self, next_z, reward, terminated, task):
""" """
Compute the TD-target from a reward and the observation at the following time step. Compute the TD-target from a reward and the observation at the following time step.
Args: Args:
next_z (torch.Tensor): Latent state at the following time step. next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step. reward (torch.Tensor): Reward at the current time step.
terminated (torch.Tensor): Termination signal at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments). task (torch.Tensor): Task index (only used for multi-task experiments).
Returns: Returns:
@@ -213,7 +217,7 @@ class TDMPC2:
""" """
pi = self.model.pi(next_z, task)[1] pi = self.model.pi(next_z, task)[1]
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True) return reward + discount * (1-terminated) * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer): def update(self, buffer):
""" """
@@ -225,12 +229,12 @@ class TDMPC2:
Returns: Returns:
dict: Dictionary of training statistics. dict: Dictionary of training statistics.
""" """
obs, action, reward, task = buffer.sample() obs, action, reward, terminated, task = buffer.sample()
# Compute targets # Compute targets
with torch.no_grad(): with torch.no_grad():
next_z = self.model.encode(obs[1:], task) next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task) td_targets = self._td_target(next_z, reward, terminated, task)
# Prepare for update # Prepare for update
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
@@ -250,19 +254,23 @@ class TDMPC2:
_zs = zs[:-1] _zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all') qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task) reward_preds = self.model.reward(_zs, action, task)
terminated_preds = self.model.terminated(_zs, task)
# Compute losses # Compute losses
reward_loss, value_loss = 0, 0 reward_loss, terminated_loss, value_loss = 0, 0, 0
for t in range(self.cfg.horizon): for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
for q in range(self.cfg.num_q): for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon) consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon) reward_loss *= (1/self.cfg.horizon)
terminated_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q)) value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = ( total_loss = (
self.cfg.consistency_coef * consistency_loss + self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss + self.cfg.reward_coef * reward_loss +
self.cfg.terminated_coef * terminated_loss +
self.cfg.value_coef * value_loss self.cfg.value_coef * value_loss
) )
@@ -282,6 +290,7 @@ class TDMPC2:
return { return {
"consistency_loss": float(consistency_loss.mean().item()), "consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()), "reward_loss": float(reward_loss.mean().item()),
"terminated_loss": float(terminated_loss.mean().item()),
"value_loss": float(value_loss.mean().item()), "value_loss": float(value_loss.mean().item()),
"pi_loss": pi_loss, "pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()), "total_loss": float(total_loss.mean().item()),

View File

@@ -47,8 +47,10 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes), episode_success=np.nanmean(ep_successes),
) )
def to_td(self, obs, action=None, reward=None): def to_td(self, obs=None, action=None, reward=None, terminated=None):
"""Creates a TensorDict for a new episode.""" """Creates a TensorDict for a new episode."""
if obs is None:
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
if isinstance(obs, dict): if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu') obs = TensorDict(obs, batch_size=(), device='cpu')
else: else:
@@ -57,10 +59,13 @@ class OnlineTrainer(Trainer):
action = torch.full_like(self.env.rand_act(), float('nan')) action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None: if reward is None:
reward = torch.tensor(float('nan')) reward = torch.tensor(float('nan'))
if terminated is None:
terminated = torch.tensor(float('nan'))
td = TensorDict(dict( td = TensorDict(dict(
obs=obs, obs=obs,
action=action.unsqueeze(0), action=action.unsqueeze(0),
reward=reward.unsqueeze(0), reward=reward.unsqueeze(0),
terminated=terminated.unsqueeze(0),
), batch_size=(1,)) ), batch_size=(1,))
return td return td
@@ -88,6 +93,7 @@ class OnlineTrainer(Trainer):
) )
train_metrics.update(self.common_metrics()) train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train') self.logger.log(train_metrics, 'train')
self._tds.append(self.to_td()) # Separate episodes with NaNs
self._ep_idx = self.buffer.add(torch.cat(self._tds)) self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset() obs = self.env.reset()
@@ -99,7 +105,7 @@ class OnlineTrainer(Trainer):
else: else:
action = self.env.rand_act() action = self.env.rand_act()
obs, reward, done, info = self.env.step(action) obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward)) self._tds.append(self.to_td(obs, action, reward, info['terminated']))
# Update agent # Update agent
if self._step >= self.cfg.seed_steps: if self._step >= self.cfg.seed_steps: