init
This commit is contained in:
@@ -78,8 +78,9 @@ class Buffer():
|
|||||||
obs = td['obs']
|
obs = td['obs']
|
||||||
action = td['action'][1:]
|
action = td['action'][1:]
|
||||||
reward = td['reward'][1:].unsqueeze(-1)
|
reward = td['reward'][1:].unsqueeze(-1)
|
||||||
|
terminated = td['terminated'][1:].unsqueeze(-1)
|
||||||
task = td['task'][0] if 'task' in td.keys() else None
|
task = td['task'][0] if 'task' in td.keys() else None
|
||||||
return self._to_device(obs, action, reward, task)
|
return self._to_device(obs, action, reward, terminated, task)
|
||||||
|
|
||||||
def add(self, td):
|
def add(self, td):
|
||||||
"""Add an episode to the buffer."""
|
"""Add an episode to the buffer."""
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class WorldModel(nn.Module):
|
|||||||
self._encoder = layers.enc(cfg)
|
self._encoder = layers.enc(cfg)
|
||||||
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
self._dynamics = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], cfg.latent_dim, act=layers.SimNorm(cfg))
|
||||||
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
self._reward = layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1))
|
||||||
|
self._terminated = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 1)
|
||||||
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
self._pi = layers.mlp(cfg.latent_dim + cfg.task_dim, 2*[cfg.mlp_dim], 2*cfg.action_dim)
|
||||||
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
self._Qs = layers.Ensemble([layers.mlp(cfg.latent_dim + cfg.action_dim + cfg.task_dim, 2*[cfg.mlp_dim], max(cfg.num_bins, 1), dropout=cfg.dropout) for _ in range(cfg.num_q)])
|
||||||
self.apply(init.weight_init)
|
self.apply(init.weight_init)
|
||||||
@@ -119,6 +120,15 @@ class WorldModel(nn.Module):
|
|||||||
z = torch.cat([z, a], dim=-1)
|
z = torch.cat([z, a], dim=-1)
|
||||||
return self._reward(z)
|
return self._reward(z)
|
||||||
|
|
||||||
|
def terminated(self, z, task):
|
||||||
|
"""
|
||||||
|
Predicts termination signal.
|
||||||
|
"""
|
||||||
|
assert task is None
|
||||||
|
if self.cfg.multitask:
|
||||||
|
z = self.task_emb(z, task)
|
||||||
|
return torch.sigmoid(self._terminated(z))
|
||||||
|
|
||||||
def pi(self, z, task):
|
def pi(self, z, task):
|
||||||
"""
|
"""
|
||||||
Samples an action from the policy prior.
|
Samples an action from the policy prior.
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ steps: 10_000_000
|
|||||||
batch_size: 256
|
batch_size: 256
|
||||||
reward_coef: 0.1
|
reward_coef: 0.1
|
||||||
value_coef: 0.1
|
value_coef: 0.1
|
||||||
|
terminated_coef: 0.1
|
||||||
consistency_coef: 20
|
consistency_coef: 20
|
||||||
rho: 0.5
|
rho: 0.5
|
||||||
lr: 3e-4
|
lr: 3e-4
|
||||||
|
|||||||
@@ -47,9 +47,12 @@ class ManiSkillWrapper(gym.Wrapper):
|
|||||||
def step(self, action):
|
def step(self, action):
|
||||||
reward = 0
|
reward = 0
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
obs, r, _, info = self.env.step(action)
|
obs, r, done, info = self.env.step(action)
|
||||||
reward += r
|
reward += r
|
||||||
return obs, reward, False, info
|
info['terminated'] = done
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
return obs, reward, done, info
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unwrapped(self):
|
def unwrapped(self):
|
||||||
|
|||||||
@@ -37,4 +37,5 @@ class TensorWrapper(gym.Wrapper):
|
|||||||
obs, reward, done, info = self.env.step(action.numpy())
|
obs, reward, done, info = self.env.step(action.numpy())
|
||||||
info = defaultdict(float, info)
|
info = defaultdict(float, info)
|
||||||
info['success'] = float(info['success'])
|
info['success'] = float(info['success'])
|
||||||
|
info['terminated'] = torch.tensor(float(info['terminated']))
|
||||||
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
return self._obs_to_tensor(obs), torch.tensor(reward, dtype=torch.float32), done, info
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class TDMPC2:
|
|||||||
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
{'params': self.model._encoder.parameters(), 'lr': self.cfg.lr*self.cfg.enc_lr_scale},
|
||||||
{'params': self.model._dynamics.parameters()},
|
{'params': self.model._dynamics.parameters()},
|
||||||
{'params': self.model._reward.parameters()},
|
{'params': self.model._reward.parameters()},
|
||||||
|
{'params': self.model._terminated.parameters()},
|
||||||
{'params': self.model._Qs.parameters()},
|
{'params': self.model._Qs.parameters()},
|
||||||
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
|
{'params': self.model._task_emb.parameters() if self.cfg.multitask else []}
|
||||||
], lr=self.cfg.lr)
|
], lr=self.cfg.lr)
|
||||||
@@ -95,12 +96,14 @@ class TDMPC2:
|
|||||||
def _estimate_value(self, z, actions, task):
|
def _estimate_value(self, z, actions, task):
|
||||||
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
"""Estimate value of a trajectory starting at latent state z and executing given actions."""
|
||||||
G, discount = 0, 1
|
G, discount = 0, 1
|
||||||
|
terminated = torch.zeros(self.cfg.num_samples, 1, dtype=torch.float32, device=z.device)
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
reward = math.two_hot_inv(self.model.reward(z, actions[t], task), self.cfg)
|
||||||
z = self.model.next(z, actions[t], task)
|
z = self.model.next(z, actions[t], task)
|
||||||
G += discount * reward
|
G += discount * (1-terminated) * reward
|
||||||
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount *= self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
terminated = torch.clip_(terminated + (self.model.terminated(z, task) > 0.5).float(), max=1.)
|
||||||
|
return G + discount * (1-terminated) * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def plan(self, z, t0=False, eval_mode=False, task=None):
|
def plan(self, z, t0=False, eval_mode=False, task=None):
|
||||||
@@ -199,13 +202,14 @@ class TDMPC2:
|
|||||||
return pi_loss.item()
|
return pi_loss.item()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _td_target(self, next_z, reward, task):
|
def _td_target(self, next_z, reward, terminated, task):
|
||||||
"""
|
"""
|
||||||
Compute the TD-target from a reward and the observation at the following time step.
|
Compute the TD-target from a reward and the observation at the following time step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
next_z (torch.Tensor): Latent state at the following time step.
|
next_z (torch.Tensor): Latent state at the following time step.
|
||||||
reward (torch.Tensor): Reward at the current time step.
|
reward (torch.Tensor): Reward at the current time step.
|
||||||
|
terminated (torch.Tensor): Termination signal at the current time step.
|
||||||
task (torch.Tensor): Task index (only used for multi-task experiments).
|
task (torch.Tensor): Task index (only used for multi-task experiments).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -213,7 +217,7 @@ class TDMPC2:
|
|||||||
"""
|
"""
|
||||||
pi = self.model.pi(next_z, task)[1]
|
pi = self.model.pi(next_z, task)[1]
|
||||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
return reward + discount * (1-terminated) * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||||
|
|
||||||
def update(self, buffer):
|
def update(self, buffer):
|
||||||
"""
|
"""
|
||||||
@@ -225,12 +229,12 @@ class TDMPC2:
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Dictionary of training statistics.
|
dict: Dictionary of training statistics.
|
||||||
"""
|
"""
|
||||||
obs, action, reward, task = buffer.sample()
|
obs, action, reward, terminated, task = buffer.sample()
|
||||||
|
|
||||||
# Compute targets
|
# Compute targets
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
next_z = self.model.encode(obs[1:], task)
|
next_z = self.model.encode(obs[1:], task)
|
||||||
td_targets = self._td_target(next_z, reward, task)
|
td_targets = self._td_target(next_z, reward, terminated, task)
|
||||||
|
|
||||||
# Prepare for update
|
# Prepare for update
|
||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
@@ -250,19 +254,23 @@ class TDMPC2:
|
|||||||
_zs = zs[:-1]
|
_zs = zs[:-1]
|
||||||
qs = self.model.Q(_zs, action, task, return_type='all')
|
qs = self.model.Q(_zs, action, task, return_type='all')
|
||||||
reward_preds = self.model.reward(_zs, action, task)
|
reward_preds = self.model.reward(_zs, action, task)
|
||||||
|
terminated_preds = self.model.terminated(_zs, task)
|
||||||
|
|
||||||
# Compute losses
|
# Compute losses
|
||||||
reward_loss, value_loss = 0, 0
|
reward_loss, terminated_loss, value_loss = 0, 0, 0
|
||||||
for t in range(self.cfg.horizon):
|
for t in range(self.cfg.horizon):
|
||||||
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
reward_loss += math.soft_ce(reward_preds[t], reward[t], self.cfg).mean() * self.cfg.rho**t
|
||||||
|
terminated_loss += F.binary_cross_entropy(terminated_preds[t], terminated[t]) * self.cfg.rho**t
|
||||||
for q in range(self.cfg.num_q):
|
for q in range(self.cfg.num_q):
|
||||||
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
value_loss += math.soft_ce(qs[q][t], td_targets[t], self.cfg).mean() * self.cfg.rho**t
|
||||||
consistency_loss *= (1/self.cfg.horizon)
|
consistency_loss *= (1/self.cfg.horizon)
|
||||||
reward_loss *= (1/self.cfg.horizon)
|
reward_loss *= (1/self.cfg.horizon)
|
||||||
|
terminated_loss *= (1/self.cfg.horizon)
|
||||||
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
|
value_loss *= (1/(self.cfg.horizon * self.cfg.num_q))
|
||||||
total_loss = (
|
total_loss = (
|
||||||
self.cfg.consistency_coef * consistency_loss +
|
self.cfg.consistency_coef * consistency_loss +
|
||||||
self.cfg.reward_coef * reward_loss +
|
self.cfg.reward_coef * reward_loss +
|
||||||
|
self.cfg.terminated_coef * terminated_loss +
|
||||||
self.cfg.value_coef * value_loss
|
self.cfg.value_coef * value_loss
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -282,6 +290,7 @@ class TDMPC2:
|
|||||||
return {
|
return {
|
||||||
"consistency_loss": float(consistency_loss.mean().item()),
|
"consistency_loss": float(consistency_loss.mean().item()),
|
||||||
"reward_loss": float(reward_loss.mean().item()),
|
"reward_loss": float(reward_loss.mean().item()),
|
||||||
|
"terminated_loss": float(terminated_loss.mean().item()),
|
||||||
"value_loss": float(value_loss.mean().item()),
|
"value_loss": float(value_loss.mean().item()),
|
||||||
"pi_loss": pi_loss,
|
"pi_loss": pi_loss,
|
||||||
"total_loss": float(total_loss.mean().item()),
|
"total_loss": float(total_loss.mean().item()),
|
||||||
|
|||||||
@@ -47,8 +47,10 @@ class OnlineTrainer(Trainer):
|
|||||||
episode_success=np.nanmean(ep_successes),
|
episode_success=np.nanmean(ep_successes),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_td(self, obs, action=None, reward=None):
|
def to_td(self, obs=None, action=None, reward=None, terminated=None):
|
||||||
"""Creates a TensorDict for a new episode."""
|
"""Creates a TensorDict for a new episode."""
|
||||||
|
if obs is None:
|
||||||
|
obs = torch.full((*self.cfg.obs_shape[self.cfg.obs],), float('nan'))
|
||||||
if isinstance(obs, dict):
|
if isinstance(obs, dict):
|
||||||
obs = TensorDict(obs, batch_size=(), device='cpu')
|
obs = TensorDict(obs, batch_size=(), device='cpu')
|
||||||
else:
|
else:
|
||||||
@@ -57,10 +59,13 @@ class OnlineTrainer(Trainer):
|
|||||||
action = torch.full_like(self.env.rand_act(), float('nan'))
|
action = torch.full_like(self.env.rand_act(), float('nan'))
|
||||||
if reward is None:
|
if reward is None:
|
||||||
reward = torch.tensor(float('nan'))
|
reward = torch.tensor(float('nan'))
|
||||||
|
if terminated is None:
|
||||||
|
terminated = torch.tensor(float('nan'))
|
||||||
td = TensorDict(dict(
|
td = TensorDict(dict(
|
||||||
obs=obs,
|
obs=obs,
|
||||||
action=action.unsqueeze(0),
|
action=action.unsqueeze(0),
|
||||||
reward=reward.unsqueeze(0),
|
reward=reward.unsqueeze(0),
|
||||||
|
terminated=terminated.unsqueeze(0),
|
||||||
), batch_size=(1,))
|
), batch_size=(1,))
|
||||||
return td
|
return td
|
||||||
|
|
||||||
@@ -88,6 +93,7 @@ class OnlineTrainer(Trainer):
|
|||||||
)
|
)
|
||||||
train_metrics.update(self.common_metrics())
|
train_metrics.update(self.common_metrics())
|
||||||
self.logger.log(train_metrics, 'train')
|
self.logger.log(train_metrics, 'train')
|
||||||
|
self._tds.append(self.to_td()) # Separate episodes with NaNs
|
||||||
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
self._ep_idx = self.buffer.add(torch.cat(self._tds))
|
||||||
|
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
@@ -99,7 +105,7 @@ class OnlineTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
action = self.env.rand_act()
|
action = self.env.rand_act()
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
self._tds.append(self.to_td(obs, action, reward, info['terminated']))
|
||||||
|
|
||||||
# Update agent
|
# Update agent
|
||||||
if self._step >= self.cfg.seed_steps:
|
if self._step >= self.cfg.seed_steps:
|
||||||
|
|||||||
Reference in New Issue
Block a user