This commit is contained in:
Nicklas Hansen
2024-01-07 18:16:33 -08:00
parent 0f3bc77011
commit cc62c4c9ce
7 changed files with 43 additions and 12 deletions

View File

@@ -78,8 +78,9 @@ class Buffer():
obs = td['obs']
action = td['action'][1:]
reward = td['reward'][1:].unsqueeze(-1)
terminated = td['terminated'][1:].unsqueeze(-1)
task = td['task'][0] if 'task' in td.keys() else None
return self._to_device(obs, action, reward, task)
return self._to_device(obs, action, reward, terminated, task)
def add(self, td):
"""Add an episode to the buffer."""

View File

@@ -24,6 +24,7 @@ class WorldModel(nn.Module):
self._encoder = layers.enc(cfg)
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
self.apply(init.weight_init)
@@ -118,6 +119,15 @@ class WorldModel(nn.Module):
z = self.task_emb(z, task)
z = torch.cat([z, a], dim=-1)
return self._reward(z)
def terminated(self, z, task):
"""
Predicts termination signal.
"""
assert task is None
if self.cfg.multitask:
z = self.task_emb(z, task)
return torch.sigmoid(self._terminated(z))
def pi(self, z, task):
"""

View File

@@ -15,6 +15,7 @@ steps: 10_000_000
batch_size: 256
reward_coef: 0.1
value_coef: 0.1
terminated_coef: 0.1
consistency_coef: 20
rho: 0.5
lr: 3e-4

View File

@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
def step(self, action):
reward = 0
for _ in range(2):
obs, r, _, info = self.env.step(action)
obs, r, done, info = self.env.step(action)
reward += r
return obs, reward, False, info
info['terminated'] = done
if done:
break
return obs, reward, done, info
@property
def unwrapped(self):

View File

@@ -37,4 +37,5 @@ class TensorWrapper(gym.Wrapper):
obs, reward, done, info = self.env.step(action.numpy())
info = defaultdict(float, info)
info['success'] = float(info['success'])
info['terminated'] = torch.tensor(float(info['terminated']))
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info

View File

@@ -22,6 +22,7 @@ class TDMPC2:
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
{'params': self.model._dynamics.parameters()},
{'params': self.model._reward.parameters()},
{'params': self.model._terminated.parameters()},
{'params': self.model._Qs.parameters()},
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
], lr=self.cfg.lr)
@@ -95,12 +96,14 @@ class TDMPC2:
def _estimate_value(self, z, actions, task):
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
G, discount = 0, 1
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
for t in range(self.cfg.horizon):
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
z = self.model.next(z, actions[t], task)
G += discount * reward
G += discount * (1-terminated) * reward
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
terminated = torch.clip_(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
return G + discount * (1-terminated) * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
@torch.no_grad()
def plan(self, z, t0=False, eval_mode=False, task=None):
@@ -199,13 +202,14 @@ class TDMPC2:
return pi_loss.item()
@torch.no_grad()
def _td_target(self, next_z, reward, task):
def _td_target(self, next_z, reward, terminated, task):
"""
Compute the TD-target from a reward and the observation at the following time step.
Args:
next_z (torch.Tensor): Latent state at the following time step.
reward (torch.Tensor): Reward at the current time step.
terminated (torch.Tensor): Termination signal at the current time step.
task (torch.Tensor): Task index (only used for multi-task experiments).
Returns:
@@ -213,7 +217,7 @@ class TDMPC2:
"""
pi = self.model.pi(next_z, task)[1]
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
return reward + discount * (1-terminated) * self.model.Q(next_z, pi, task, return_type='min', target=True)
def update(self, buffer):
"""
@@ -225,12 +229,12 @@ class TDMPC2:
Returns:
dict: Dictionary of training statistics.
"""
obs, action, reward, task = buffer.sample()
obs, action, reward, terminated, task = buffer.sample()
# Compute targets
with torch.no_grad():
next_z = self.model.encode(obs[1:], task)
td_targets = self._td_target(next_z, reward, task)
td_targets = self._td_target(next_z, reward, terminated, task)
# Prepare for update
self.optim.zero_grad(set_to_none=True)
@@ -250,19 +254,23 @@ class TDMPC2:
_zs = zs[:-1]
qs = self.model.Q(_zs, action, task, return_type='all')
reward_preds = self.model.reward(_zs, action, task)
terminated_preds = self.model.terminated(_zs, task)
# Compute losses
reward_loss, value_loss = 0, 0
reward_loss, terminated_loss, value_loss = 0, 0, 0
for t in range(self.cfg.horizon):
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
for q in range(self.cfg.num_q):
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
consistency_loss *= (1/self.cfg.horizon)
reward_loss *= (1/self.cfg.horizon)
terminated_loss *= (1/self.cfg.horizon)
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
total_loss = (
self.cfg.consistency_coef * consistency_loss +
self.cfg.reward_coef * reward_loss +
self.cfg.terminated_coef * terminated_loss +
self.cfg.value_coef * value_loss
)
@@ -282,6 +290,7 @@ class TDMPC2:
return {
"consistency_loss": float(consistency_loss.mean().item()),
"reward_loss": float(reward_loss.mean().item()),
"terminated_loss": float(terminated_loss.mean().item()),
"value_loss": float(value_loss.mean().item()),
"pi_loss": pi_loss,
"total_loss": float(total_loss.mean().item()),

View File

@@ -47,8 +47,10 @@ class OnlineTrainer(Trainer):
episode_success=np.nanmean(ep_successes),
)
def to_td(self, obs, action=None, reward=None):
def to_td(self, obs=None, action=None, reward=None, terminated=None):
"""Creates a TensorDict for a new episode."""
if obs is None:
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
if isinstance(obs, dict):
obs = TensorDict(obs, batch_size=(), device='cpu')
else:
@@ -57,10 +59,13 @@ class OnlineTrainer(Trainer):
action = torch.full_like(self.env.rand_act(), float('nan'))
if reward is None:
reward = torch.tensor(float('nan'))
if terminated is None:
terminated = torch.tensor(float('nan'))
td = TensorDict(dict(
obs=obs,
action=action.unsqueeze(0),
reward=reward.unsqueeze(0),
terminated=terminated.unsqueeze(0),
), batch_size=(1,))
return td
@@ -88,6 +93,7 @@ class OnlineTrainer(Trainer):
)
train_metrics.update(self.common_metrics())
self.logger.log(train_metrics, 'train')
self._tds.append(self.to_td()) # Separate episodes with NaNs
self._ep_idx = self.buffer.add(torch.cat(self._tds))
obs = self.env.reset()
@@ -99,7 +105,7 @@ class OnlineTrainer(Trainer):
else:
action = self.env.rand_act()
obs, reward, done, info = self.env.step(action)
self._tds.append(self.to_td(obs, action, reward))
self._tds.append(self.to_td(obs, action, reward, info['terminated']))
# Update agent
if self._step >= self.cfg.seed_steps: