add discrete planning
This commit is contained in:
@@ -175,6 +175,7 @@ class WorldModel(nn.Module):
|
|||||||
if z.dim() == 2:
|
if z.dim() == 2:
|
||||||
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
|
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
|
||||||
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
|
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
|
||||||
|
actions = actions.repeat(z.shape[0], 1, 1)
|
||||||
elif z.dim() == 3:
|
elif z.dim() == 3:
|
||||||
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
|
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
|
||||||
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
|
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ exp_name: default
|
|||||||
data_dir: ???
|
data_dir: ???
|
||||||
|
|
||||||
# planning
|
# planning
|
||||||
mpc: false
|
mpc: true
|
||||||
iterations: 6
|
iterations: 6
|
||||||
num_samples: 512
|
num_samples: 512
|
||||||
num_elites: 64
|
num_elites: 64
|
||||||
|
|||||||
@@ -121,7 +121,10 @@ class TDMPC2(torch.nn.Module):
|
|||||||
G = G + discount * reward
|
G = G + discount * reward
|
||||||
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
discount_update = self.discount[torch.tensor(task)] if self.cfg.multitask else self.discount
|
||||||
discount = discount * discount_update
|
discount = discount * discount_update
|
||||||
return G + discount * self.model.Q(z, self.model.pi(z, task)[1], task, return_type='avg')
|
pi = self.model.pi(z, task)[1]
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
pi = pi.squeeze(1) # TODO: this is a bit hacky
|
||||||
|
return G + discount * self.model.Q(z, pi, task, return_type='avg')
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
def _plan(self, obs, t0=False, eval_mode=False, task=None):
|
||||||
@@ -143,12 +146,19 @@ class TDMPC2(torch.nn.Module):
|
|||||||
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
pi_actions = torch.empty(self.cfg.horizon, self.cfg.num_pi_trajs, self.cfg.action_dim, device=self.device)
|
||||||
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
_z = z.repeat(self.cfg.num_pi_trajs, 1)
|
||||||
for t in range(self.cfg.horizon-1):
|
for t in range(self.cfg.horizon-1):
|
||||||
pi_actions[t] = self.model.pi(_z, task)[1]
|
action = self.model.pi(_z, task)[1]
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
action = action.squeeze(1)
|
||||||
|
pi_actions[t] = action
|
||||||
_z = self.model.next(_z, pi_actions[t], task)
|
_z = self.model.next(_z, pi_actions[t], task)
|
||||||
pi_actions[-1] = self.model.pi(_z, task)[1]
|
action = self.model.pi(_z, task)[1]
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
action = action.squeeze(1)
|
||||||
|
pi_actions[-1] = action
|
||||||
|
|
||||||
# Initialize state and parameters
|
# Initialize state and parameters
|
||||||
z = z.repeat(self.cfg.num_samples, 1)
|
z = z.repeat(self.cfg.num_samples, 1)
|
||||||
|
if self.cfg.action == 'continuous':
|
||||||
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
mean = torch.zeros(self.cfg.horizon, self.cfg.action_dim, device=self.device)
|
||||||
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
std = torch.full((self.cfg.horizon, self.cfg.action_dim), self.cfg.max_std, dtype=torch.float, device=self.device)
|
||||||
if not t0:
|
if not t0:
|
||||||
@@ -157,6 +167,25 @@ class TDMPC2(torch.nn.Module):
|
|||||||
if self.cfg.num_pi_trajs > 0:
|
if self.cfg.num_pi_trajs > 0:
|
||||||
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
actions[:, :self.cfg.num_pi_trajs] = pi_actions
|
||||||
|
|
||||||
|
# Random shooting
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
# Sample actions
|
||||||
|
actions_sample = torch.randint(0, self.cfg.action_dim, (self.cfg.horizon, self.cfg.num_samples-self.cfg.num_pi_trajs), device=actions.device)
|
||||||
|
actions[:, self.cfg.num_pi_trajs:] = math.int_to_one_hot(actions_sample, self.cfg.action_dim)
|
||||||
|
|
||||||
|
# Compute elite actions
|
||||||
|
value = self._estimate_value(z, actions, task).nan_to_num(0)
|
||||||
|
elite_idxs = torch.topk(value.squeeze(1), self.cfg.num_elites, dim=0).indices
|
||||||
|
elite_value, elite_actions = value[elite_idxs], actions[:, elite_idxs]
|
||||||
|
|
||||||
|
# Sample action according to score
|
||||||
|
max_value = elite_value.max(0).values
|
||||||
|
score = torch.exp(self.cfg.temperature*(elite_value - max_value))
|
||||||
|
score = score / score.sum(0)
|
||||||
|
rand_idx = math.gumbel_softmax_sample(score.squeeze(1)) # gumbel_softmax_sample is compatible with cuda graphs
|
||||||
|
actions = torch.index_select(elite_actions, 1, rand_idx).squeeze(1)
|
||||||
|
return actions[0]
|
||||||
|
|
||||||
# Iterate MPPI
|
# Iterate MPPI
|
||||||
for _ in range(self.cfg.iterations):
|
for _ in range(self.cfg.iterations):
|
||||||
|
|
||||||
@@ -191,6 +220,7 @@ class TDMPC2(torch.nn.Module):
|
|||||||
if not eval_mode:
|
if not eval_mode:
|
||||||
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
a = a + std * torch.randn(self.cfg.action_dim, device=std.device)
|
||||||
self._prev_mean.copy_(mean)
|
self._prev_mean.copy_(mean)
|
||||||
|
|
||||||
return a.clamp(-1, 1)
|
return a.clamp(-1, 1)
|
||||||
|
|
||||||
def update_pi(self, zs, task):
|
def update_pi(self, zs, task):
|
||||||
|
|||||||
Reference in New Issue
Block a user