init
This commit is contained in:
@@ -78,8 +78,9 @@ class Buffer():
|
||||
obs = td['obs']
|
||||
action = td['action'][1:]
|
||||
reward = td['reward'][1:].unsqueeze(-1)
|
||||
terminated = td['terminated'][1:].unsqueeze(-1)
|
||||
task = td['task'][0] if 'task' in td.keys() else None
|
||||
return self._to_device(obs, action, reward, task)
|
||||
return self._to_device(obs, action, reward, terminated, task)
|
||||
|
||||
def add(self, td):
|
||||
"""Add an episode to the buffer."""
|
||||
|
||||
@@ -24,6 +24,7 @@ class WorldModel(nn.Module):
|
||||
self._encoder = layers.enc(cfg)
|
||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||
self.apply(init.weight_init)
|
||||
@@ -118,6 +119,15 @@ class WorldModel(nn.Module):
|
||||
z = self.task_emb(z, task)
|
||||
z = torch.cat([z, a], dim=-1)
|
||||
return self._reward(z)
|
||||
|
||||
def terminated(self, z, task):
|
||||
"""
|
||||
Predicts termination signal.
|
||||
"""
|
||||
assert task is None
|
||||
if self.cfg.multitask:
|
||||
z = self.task_emb(z, task)
|
||||
return torch.sigmoid(self._terminated(z))
|
||||
|
||||
def pi(self, z, task):
|
||||
"""
|
||||
|
||||
@@ -15,6 +15,7 @@ steps: 10_000_000
|
||||
batch_size: 256
|
||||
reward_coef: 0.1
|
||||
value_coef: 0.1
|
||||
terminated_coef: 0.1
|
||||
consistency_coef: 20
|
||||
rho: 0.5
|
||||
lr: 3e-4
|
||||
|
||||
@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
|
||||
def step(self, action):
|
||||
reward = 0
|
||||
for _ in range(2):
|
||||
obs, r, _, info = self.env.step(action)
|
||||
obs, r, done, info = self.env.step(action)
|
||||
reward += r
|
||||
return obs, reward, False, info
|
||||
info['terminated'] = done
|
||||
if done:
|
||||
break
|
||||
return obs, reward, done, info
|
||||
|
||||
@property
|
||||
def unwrapped(self):
|
||||
|
||||
@@ -37,4 +37,5 @@ class TensorWrapper(gym.Wrapper):
|
||||
obs, reward, done, info = self.env.step(action.numpy())
|
||||
info = defaultdict(float, info)
|
||||
info['success'] = float(info['success'])
|
||||
info['terminated'] = torch.tensor(float(info['terminated']))
|
||||
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
||||
|
||||
@@ -22,6 +22,7 @@ class TDMPC2:
|
||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||
{'params': self.model._dynamics.parameters()},
|
||||
{'params': self.model._reward.parameters()},
|
||||
{'params': self.model._terminated.parameters()},
|
||||
{'params': self.model._Qs.parameters()},
|
||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
|
||||
], lr=self.cfg.lr)
|
||||
@@ -95,12 +96,14 @@ class TDMPC2:
|
||||
def _estimate_value(self, z, actions, task):
|
||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||
G, discount = 0, 1
|
||||
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||
for t in range(self.cfg.horizon):
|
||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||
z = self.model.next(z, actions[t], task)
|
||||
G += discount * reward
|
||||
G += discount * (1-terminated) * reward
|
||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
terminated = torch.clip_(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
|
||||
return G + discount * (1-terminated) * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z, t0=False, eval_mode=False, task=None):
|
||||
@@ -199,13 +202,14 @@ class TDMPC2:
|
||||
return pi_loss.item()
|
||||
|
||||
@torch.no_grad()
|
||||
def _td_target(self, next_z, reward, task):
|
||||
def _td_target(self, next_z, reward, terminated, task):
|
||||
"""
|
||||
Compute the TD-target from a reward and the observation at the following time step.
|
||||
|
||||
Args:
|
||||
next_z (torch.Tensor): Latent state at the following time step.
|
||||
reward (torch.Tensor): Reward at the current time step.
|
||||
terminated (torch.Tensor): Termination signal at the current time step.
|
||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||
|
||||
Returns:
|
||||
@@ -213,7 +217,7 @@ class TDMPC2:
|
||||
"""
|
||||
pi = self.model.pi(next_z, task)[1]
|
||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||
return reward + discount * (1-terminated) * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||
|
||||
def update(self, buffer):
|
||||
"""
|
||||
@@ -225,12 +229,12 @@ class TDMPC2:
|
||||
Returns:
|
||||
dict: Dictionary of training statistics.
|
||||
"""
|
||||
obs, action, reward, task = buffer.sample()
|
||||
obs, action, reward, terminated, task = buffer.sample()
|
||||
|
||||
# Compute targets
|
||||
with torch.no_grad():
|
||||
next_z = self.model.encode(obs[1:], task)
|
||||
td_targets = self._td_target(next_z, reward, task)
|
||||
td_targets = self._td_target(next_z, reward, terminated, task)
|
||||
|
||||
# Prepare for update
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
@@ -250,19 +254,23 @@ class TDMPC2:
|
||||
_zs = zs[:-1]
|
||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||
reward_preds = self.model.reward(_zs, action, task)
|
||||
terminated_preds = self.model.terminated(_zs, task)
|
||||
|
||||
# Compute losses
|
||||
reward_loss, value_loss = 0, 0
|
||||
reward_loss, terminated_loss, value_loss = 0, 0, 0
|
||||
for t in range(self.cfg.horizon):
|
||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
||||
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
|
||||
for q in range(self.cfg.num_q):
|
||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
||||
consistency_loss *= (1/self.cfg.horizon)
|
||||
reward_loss *= (1/self.cfg.horizon)
|
||||
terminated_loss *= (1/self.cfg.horizon)
|
||||
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
|
||||
total_loss = (
|
||||
self.cfg.consistency_coef * consistency_loss +
|
||||
self.cfg.reward_coef * reward_loss +
|
||||
self.cfg.terminated_coef * terminated_loss +
|
||||
self.cfg.value_coef * value_loss
|
||||
)
|
||||
|
||||
@@ -282,6 +290,7 @@ class TDMPC2:
|
||||
return {
|
||||
"consistency_loss": float(consistency_loss.mean().item()),
|
||||
"reward_loss": float(reward_loss.mean().item()),
|
||||
"terminated_loss": float(terminated_loss.mean().item()),
|
||||
"value_loss": float(value_loss.mean().item()),
|
||||
"pi_loss": pi_loss,
|
||||
"total_loss": float(total_loss.mean().item()),
|
||||
|
||||
@@ -47,8 +47,10 @@ class OnlineTrainer(Trainer):
|
||||
episode_success=np.nanmean(ep_successes),
|
||||
)
|
||||
|
||||
def to_td(self, obs, action=None, reward=None):
|
||||
def to_td(self, obs=None, action=None, reward=None, terminated=None):
|
||||
"""Creates a TensorDict for a new episode."""
|
||||
if obs is None:
|
||||
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
|
||||
if isinstance(obs, dict):
|
||||
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||
else:
|
||||
@@ -57,12 +59,15 @@ class OnlineTrainer(Trainer):
|
||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||
if reward is None:
|
||||
reward = torch.tensor(float('nan'))
|
||||
if terminated is None:
|
||||
terminated = torch.tensor(float('nan'))
|
||||
td = TensorDict(dict(
|
||||
obs=obs,
|
||||
action=action.unsqueeze(0),
|
||||
reward=reward.unsqueeze(0),
|
||||
terminated=terminated.unsqueeze(0),
|
||||
), batch_size=(1,))
|
||||
return td
|
||||
return td
|
||||
|
||||
def train(self):
|
||||
"""Train a TD-MPC2 agent."""
|
||||
@@ -88,6 +93,7 @@ class OnlineTrainer(Trainer):
|
||||
)
|
||||
train_metrics.update(self.common_metrics())
|
||||
self.logger.log(train_metrics, 'train')
|
||||
self._tds.append(self.to_td()) # Separate episodes with NaNs
|
||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||
|
||||
obs = self.env.reset()
|
||||
@@ -99,7 +105,7 @@ class OnlineTrainer(Trainer):
|
||||
else:
|
||||
action = self.env.rand_act()
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self._tds.append(self.to_td(obs, action, reward))
|
||||
self._tds.append(self.to_td(obs, action, reward, info['terminated']))
|
||||
|
||||
# Update agent
|
||||
if self._step >= self.cfg.seed_steps:
|
||||
|
||||
Reference in New Issue
Block a user