add discrete planning

This commit is contained in:
Nicklas Hansen
2024-11-12 00:13:08 -08:00
parent 8280b82d5c
commit 88ad0620ca
3 changed files with 39 additions and 8 deletions

View File

@@ -175,6 +175,7 @@ class WorldModel(nn.Module):
if z.dim() == 2:
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
actions = actions.repeat(z.shape[0], 1, 1)
elif z.dim() == 3:
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)

View File

@@ -30,7 +30,7 @@ exp_name: default
data_dir: ???
# planning
mpc: false
mpc: true
iterations: 6
num_samples: 512
num_elites: 64

View File

@@ -121,7 +121,10 @@ class TDMPC2(torch.nn.Module):
G = G + discount * reward
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
discount = discount * discount_update
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
pi = self.model.pi(z, task)[1]
if self.cfg.action == 'discrete':
pi = pi.squeeze(1) # TODO: this is a bit hacky
return G + discount * self.model.Q(z, pi, task, return_type='avg')
@torch.no_grad()
def _plan(self, obs, t0=False, eval_mode=False, task=None):
@@ -143,12 +146,19 @@ class TDMPC2(torch.nn.Module):
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
_z = z.repeat(self.cfg.num_pi_trajs, 1)
for t in range(self.cfg.horizon-1):
pi_actions[t] = self.model.pi(_z, task)[1]
action = self.model.pi(_z, task)[1]
if self.cfg.action == 'discrete':
action = action.squeeze(1)
pi_actions[t] = action
_z = self.model.next(_z, pi_actions[t], task)
pi_actions[-1] = self.model.pi(_z, task)[1]
action = self.model.pi(_z, task)[1]
if self.cfg.action == 'discrete':
action = action.squeeze(1)
pi_actions[-1] = action
# Initialize state and parameters
z = z.repeat(self.cfg.num_samples, 1)
if self.cfg.action == 'continuous':
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
if not t0:
@@ -157,6 +167,25 @@ class TDMPC2(torch.nn.Module):
if self.cfg.num_pi_trajs > 0:
actions[:, :self.cfg.num_pi_trajs] = pi_actions
# Random shooting
if self.cfg.action == 'discrete':
# Sample actions
actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device)
actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim)
# Compute elite actions
value = self._estimate_value(z, actions, task).nan_to_num(0)
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
# Sample action according to score
max_value = elite_value.max(0).values
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
score = score / score.sum(0)
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
return actions[0]
# Iterate MPPI
for _ in range(self.cfg.iterations):
@@ -191,6 +220,7 @@ class TDMPC2(torch.nn.Module):
if not eval_mode:
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
self._prev_mean.copy_(mean)
return a.clamp(-1, 1)
def update_pi(self, zs, task):