add discrete planning
This commit is contained in:
@@ -175,6 +175,7 @@ class WorldModel(nn.Module):
|
||||
if z.dim() == 2:
|
||||
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
|
||||
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
|
||||
actions = actions.repeat(z.shape[0], 1, 1)
|
||||
elif z.dim() == 3:
|
||||
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
|
||||
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
|
||||
|
||||
@@ -30,7 +30,7 @@ exp_name: default
|
||||
data_dir: ???
|
||||
|
||||
# planning
|
||||
mpc: false
|
||||
mpc: true
|
||||
iterations: 6
|
||||
num_samples: 512
|
||||
num_elites: 64
|
||||
|
||||
@@ -121,7 +121,10 @@ class TDMPC2(torch.nn.Module):
|
||||
G = G + discount * reward
|
||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||
discount = discount * discount_update
|
||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
||||
pi = self.model.pi(z, task)[1]
|
||||
if self.cfg.action == 'discrete':
|
||||
pi = pi.squeeze(1) # TODO: this is a bit hacky
|
||||
return G + discount * self.model.Q(z, pi, task, return_type='avg')
|
||||
|
||||
@torch.no_grad()
|
||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||
@@ -143,12 +146,19 @@ class TDMPC2(torch.nn.Module):
|
||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||
for t in range(self.cfg.horizon-1):
|
||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
||||
action = self.model.pi(_z, task)[1]
|
||||
if self.cfg.action == 'discrete':
|
||||
action = action.squeeze(1)
|
||||
pi_actions[t] = action
|
||||
_z = self.model.next(_z, pi_actions[t], task)
|
||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
||||
action = self.model.pi(_z, task)[1]
|
||||
if self.cfg.action == 'discrete':
|
||||
action = action.squeeze(1)
|
||||
pi_actions[-1] = action
|
||||
|
||||
# Initialize state and parameters
|
||||
z = z.repeat(self.cfg.num_samples, 1)
|
||||
if self.cfg.action == 'continuous':
|
||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||
if not t0:
|
||||
@@ -157,6 +167,25 @@ class TDMPC2(torch.nn.Module):
|
||||
if self.cfg.num_pi_trajs > 0:
|
||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||
|
||||
# Random shooting
|
||||
if self.cfg.action == 'discrete':
|
||||
# Sample actions
|
||||
actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device)
|
||||
actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim)
|
||||
|
||||
# Compute elite actions
|
||||
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||
|
||||
# Sample action according to score
|
||||
max_value = elite_value.max(0).values
|
||||
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||
score = score / score.sum(0)
|
||||
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||
return actions[0]
|
||||
|
||||
# Iterate MPPI
|
||||
for _ in range(self.cfg.iterations):
|
||||
|
||||
@@ -191,6 +220,7 @@ class TDMPC2(torch.nn.Module):
|
||||
if not eval_mode:
|
||||
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||
self._prev_mean.copy_(mean)
|
||||
|
||||
return a.clamp(-1, 1)
|
||||
|
||||
def update_pi(self, zs, task):
|
||||
|
||||
Reference in New Issue
Block a user