clean up discrete planning
This commit is contained in:
123
tdmpc2/tdmpc2.py
123
tdmpc2/tdmpc2.py
@@ -106,9 +106,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
action = self.plan(obs, t0=t0, eval_mode=eval_mode, task=task)
|
||||||
else:
|
else:
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
action = self.model.pi(z, task)[int(not eval_mode)][0]
|
select_idx = int(not eval_mode or self.cfg.action_space == 'discrete')
|
||||||
if self.cfg.action_space == 'discrete':
|
action = self.model.pi(z, task)[select_idx][0]
|
||||||
action = action.squeeze(0) # TODO: this is a bit hacky
|
|
||||||
return action.cpu()
|
return action.cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -126,6 +125,33 @@ class TDMPC2(torch.nn.Module):
|
|||||||
pi = pi.squeeze(1) # TODO: this is a bit hacky
|
pi = pi.squeeze(1) # TODO: this is a bit hacky
|
||||||
return G + discount * self.model.Q(z, pi, task, return_type='avg')
|
return G + discount * self.model.Q(z, pi, task, return_type='avg')
|
||||||
|
|
||||||
|
@torch.no_grad()
def _sample_policy(self, z, task):
    """Sample trajectories from the policy prior.

    Starting from the latent state(s) ``z``, the policy is queried once per
    planning step and the latent is advanced with ``self.model.next`` between
    consecutive steps. No dynamics step is taken after the final action.

    Args:
        z: Latent state(s) to roll out from.
            NOTE(review): presumably one row per policy trajectory, i.e.
            (num_pi_trajs, latent_dim) -- confirm against the caller.
        task: Task identifier, forwarded unchanged to the model calls.

    Returns:
        torch.Tensor: Actions of shape (horizon, num_pi_trajs, action_dim),
        allocated on ``self.device``.
    """
    cfg = self.cfg
    is_discrete = cfg.action_space == 'discrete'

    def policy_action(latent):
        # Element 1 of the policy output is used as the action.
        # NOTE(review): presumably the sampled (non-deterministic) action -- confirm.
        sampled = self.model.pi(latent, task)[1]
        # Discrete policies return an extra singleton dim at position 1.
        return sampled.squeeze(1) if is_discrete else sampled

    rollout = torch.empty(cfg.horizon, cfg.num_pi_trajs, cfg.action_dim, device=self.device)

    # All steps but the last: act, then advance the latent state.
    for step in range(cfg.horizon - 1):
        rollout[step] = policy_action(z)
        z = self.model.next(z, rollout[step], task)

    # Final step: act only.
    rollout[-1] = policy_action(z)
    return rollout
|
||||||
|
|
||||||
|
@torch.no_grad()
def _sample_actions(self, n, mean=None, std=None):
    """Sample actions from a Gaussian or Categorical distribution.

    For a discrete action space, action indices are drawn uniformly from
    ``[0, action_dim)`` and one-hot encoded. Otherwise actions are drawn as
    ``mean + std * noise`` with standard-normal noise and clamped to [-1, 1].

    Args:
        n: Number of action sequences to sample. Both branches honour it.
        mean: Per-step Gaussian mean; required for continuous action spaces
            and ignored for discrete ones.
            NOTE(review): presumably (horizon, action_dim) -- confirm.
        std: Per-step Gaussian standard deviation; same requirements as
            ``mean``.

    Returns:
        torch.Tensor: Sampled actions. The continuous branch returns shape
        (horizon, n, action_dim); the discrete branch returns the one-hot
        encoding of a (horizon, n) index tensor.

    Raises:
        ValueError: If the action space is continuous and ``mean`` or ``std``
            is missing.
    """
    cfg = self.cfg
    if cfg.action_space == 'discrete':
        indices = torch.randint(0, cfg.action_dim, (cfg.horizon, n), device=self.device)
        # `math` is the project's helper module already used by this file,
        # not the standard-library `math`.
        return math.int_to_one_hot(indices, cfg.action_dim)

    if mean is None or std is None:
        raise ValueError("mean and std are required to sample continuous actions")

    # Draw exactly `n` sequences; previously the count was hard-coded to
    # num_samples - num_pi_trajs and the `n` argument was ignored here.
    noise = torch.randn(cfg.horizon, n, cfg.action_dim, device=std.device)
    return (mean.unsqueeze(1) + std.unsqueeze(1) * noise).clamp(-1, 1)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
"""
|
"""
|
||||||
@@ -140,88 +166,61 @@ class TDMPC2(torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Action to take in the environment.
|
torch.Tensor: Action to take in the environment.
|
||||||
"""
|
"""
|
||||||
# Sample policy trajectories
|
# Encode observation
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
if self.cfg.num_pi_trajs > 0:
|
|
||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
|
||||||
for t in range(self.cfg.horizon-1):
|
|
||||||
action = self.model.pi(_z, task)[1]
|
|
||||||
if self.cfg.action_space == 'discrete':
|
|
||||||
action = action.squeeze(1)
|
|
||||||
pi_actions[t] = action
|
|
||||||
_z = self.model.next(_z, pi_actions[t], task)
|
|
||||||
action = self.model.pi(_z, task)[1]
|
|
||||||
if self.cfg.action_space == 'discrete':
|
|
||||||
action = action.squeeze(1)
|
|
||||||
pi_actions[-1] = action
|
|
||||||
|
|
||||||
# Initialize state and parameters
|
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.repeat(self.cfg.num_samples, 1)
|
||||||
|
|
||||||
|
# Initialize parameters
|
||||||
if self.cfg.action_space == 'continuous':
|
if self.cfg.action_space == 'continuous':
|
||||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||||
if not t0:
|
if not t0:
|
||||||
mean[:-1] = self._prev_mean[1:]
|
mean[:-1] = self._prev_mean[1:]
|
||||||
|
else:
|
||||||
|
mean, std = None, None
|
||||||
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
actions = torch.empty(self.cfg.horizon, self.cfg.num_samples, self.cfg.action_dim, device=self.device)
|
||||||
|
|
||||||
|
# Sample policy trajectories
|
||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
actions[:, :self.cfg.num_pi_trajs] = self._sample_policy(z[:self.cfg.num_pi_trajs], task)
|
||||||
|
|
||||||
# Random shooting
|
|
||||||
if self.cfg.action_space == 'discrete':
|
|
||||||
# Sample actions
|
|
||||||
actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device)
|
|
||||||
actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim)
|
|
||||||
|
|
||||||
# Compute elite actions
|
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
|
||||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
|
||||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
|
||||||
|
|
||||||
# Sample action according to score
|
|
||||||
max_value = elite_value.max(0).values
|
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
|
||||||
score = score / score.sum(0)
|
|
||||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
|
||||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
|
||||||
return actions[0]
|
|
||||||
|
|
||||||
# Iterate MPPI
|
# Iterate MPPI
|
||||||
for _ in range(self.cfg.iterations):
|
for _ in range(self.cfg.iterations):
|
||||||
|
|
||||||
# Sample actions
|
# Sample random actions
|
||||||
r = torch.randn(self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs, self.cfg.action_dim, device=std.device)
|
actions[:, self.cfg.num_pi_trajs:] = self._sample_actions(self.cfg.num_samples-self.cfg.num_pi_trajs, mean, std)
|
||||||
actions_sample = mean.unsqueeze(1) + std.unsqueeze(1) * r
|
|
||||||
actions_sample = actions_sample.clamp(-1, 1)
|
|
||||||
actions[:, self.cfg.num_pi_trajs:] = actions_sample
|
|
||||||
if self.cfg.multitask:
|
if self.cfg.multitask:
|
||||||
actions = actions * self.model._action_masks[task]
|
actions = actions * self.model._action_masks[task]
|
||||||
|
|
||||||
# Compute elite actions
|
# Select elites and compute scores
|
||||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||||
|
|
||||||
# Update parameters
|
|
||||||
max_value = elite_value.max(0).values
|
max_value = elite_value.max(0).values
|
||||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||||
score = score / score.sum(0)
|
score = score / score.sum(0)
|
||||||
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
|
||||||
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
|
||||||
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
|
||||||
if self.cfg.multitask:
|
|
||||||
mean = mean * self.model._action_masks[task]
|
|
||||||
std = std * self.model._action_masks[task]
|
|
||||||
|
|
||||||
|
# Update parameters
|
||||||
|
if self.cfg.action_space == 'continuous':
|
||||||
|
mean = (score.unsqueeze(0) * elite_actions).sum(dim=1) / (score.sum(0) + 1e-9)
|
||||||
|
std = ((score.unsqueeze(0) * (elite_actions - mean.unsqueeze(1)) ** 2).sum(dim=1) / (score.sum(0) + 1e-9)).sqrt()
|
||||||
|
std = std.clamp(self.cfg.min_std, self.cfg.max_std)
|
||||||
|
if self.cfg.multitask:
|
||||||
|
mean = mean * self.model._action_masks[task]
|
||||||
|
std = std * self.model._action_masks[task]
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
# Select action
|
# Select action
|
||||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
action = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)[0]
|
||||||
a, std = actions[0], std[0]
|
if self.cfg.action_space == 'continuous':
|
||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
action = action + std[0] * torch.randn(self.cfg.action_dim, device=std.device)
|
||||||
self._prev_mean.copy_(mean)
|
self._prev_mean.copy_(mean)
|
||||||
|
action = action.clamp(-1, 1)
|
||||||
return a.clamp(-1, 1)
|
|
||||||
|
return action
|
||||||
|
|
||||||
def update_pi(self, zs, task):
|
def update_pi(self, zs, task):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user