argmax policy works
This commit is contained in:
@@ -156,16 +156,34 @@ class WorldModel(nn.Module):
|
|||||||
with logits predicted by a neural network.
|
with logits predicted by a neural network.
|
||||||
"""
|
"""
|
||||||
# Categorical policy prior
|
# Categorical policy prior
|
||||||
logits = self._pi(z)
|
# logits = self._pi(z)
|
||||||
policy_dist = Categorical(logits=logits)
|
# policy_dist = Categorical(logits=logits)
|
||||||
action = policy_dist.sample()
|
# action = policy_dist.sample()
|
||||||
|
# action = math.int_to_one_hot(action, self.cfg.action_dim)
|
||||||
|
|
||||||
|
# # Action probabilities for calculating the adapted soft-Q loss
|
||||||
|
# action_probs = policy_dist.probs
|
||||||
|
# log_prob = F.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
|
# return action, action, log_prob, action_probs
|
||||||
|
|
||||||
|
# Argmax policy
|
||||||
|
# enumerate all possible one-hot actions
|
||||||
|
# and return the one with the highest Q-value
|
||||||
|
# for the given state.
|
||||||
|
actions = torch.eye(self.cfg.action_dim, device=z.device).unsqueeze(0)
|
||||||
|
if z.dim() == 2:
|
||||||
|
# z (batch_size, latent_dim) -> (batch_size, action_dim, latent_dim)
|
||||||
|
z = z.unsqueeze(1).expand(-1, self.cfg.action_dim, -1)
|
||||||
|
elif z.dim() == 3:
|
||||||
|
# z (seq_len, batch_size, latent_dim) -> (seq_len, batch_size, action_dim, latent_dim)
|
||||||
|
z = z.unsqueeze(2).expand(-1, -1, self.cfg.action_dim, -1)
|
||||||
|
actions = actions.unsqueeze(0).repeat(z.shape[0], z.shape[1], 1, 1)
|
||||||
|
Q = self.Q(z, actions, task, return_type='min')
|
||||||
|
action = Q.argmax(dim=-2)
|
||||||
action = math.int_to_one_hot(action, self.cfg.action_dim)
|
action = math.int_to_one_hot(action, self.cfg.action_dim)
|
||||||
|
|
||||||
# Action probabilities for calculating the adapted soft-Q loss
|
return action, action, None, None
|
||||||
action_probs = policy_dist.probs
|
|
||||||
log_prob = F.log_softmax(logits, dim=-1)
|
|
||||||
|
|
||||||
return action, action, log_prob, action_probs
|
|
||||||
|
|
||||||
|
|
||||||
def pi(self, z, task):
|
def pi(self, z, task):
|
||||||
|
|||||||
@@ -107,6 +107,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
z = self.model.encode(obs, task)
|
z = self.model.encode(obs, task)
|
||||||
action = self.model.pi(z, task)[int(not eval_mode)][0]
|
action = self.model.pi(z, task)[int(not eval_mode)][0]
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
action = action.squeeze(0) # TODO: this is a bit hacky
|
||||||
return action.cpu()
|
return action.cpu()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -234,6 +236,8 @@ class TDMPC2(torch.nn.Module):
|
|||||||
torch.Tensor: TD-target.
|
torch.Tensor: TD-target.
|
||||||
"""
|
"""
|
||||||
pi = self.model.pi(next_z, task)[1]
|
pi = self.model.pi(next_z, task)[1]
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
pi = pi.squeeze(2) # TODO: this is a bit hacky
|
||||||
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
discount = self.discount[task].unsqueeze(-1) if self.cfg.multitask else self.discount
|
||||||
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
return reward + discount * self.model.Q(next_z, pi, task, return_type='min', target=True)
|
||||||
|
|
||||||
@@ -284,7 +288,10 @@ class TDMPC2(torch.nn.Module):
|
|||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# Update policy
|
# Update policy
|
||||||
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
|
if self.cfg.action == 'continuous':
|
||||||
|
pi_loss, pi_grad_norm = self.update_pi(zs.detach(), task)
|
||||||
|
else:
|
||||||
|
pi_loss, pi_grad_norm = 0., 0.
|
||||||
|
|
||||||
# Update target Q-functions
|
# Update target Q-functions
|
||||||
self.model.soft_update_target_Q()
|
self.model.soft_update_target_Q()
|
||||||
|
|||||||
@@ -98,6 +98,11 @@ class OnlineTrainer(Trainer):
|
|||||||
action = self.agent.act(obs, t0=len(self._tds)==1)
|
action = self.agent.act(obs, t0=len(self._tds)==1)
|
||||||
else:
|
else:
|
||||||
action = self.env.rand_act()
|
action = self.env.rand_act()
|
||||||
|
if self.cfg.action == 'discrete':
|
||||||
|
# exploration schedule
|
||||||
|
# minimum 0.01, maximum 0.05, anneal over 20k steps
|
||||||
|
if torch.rand(1) < 0.01 + (0.05 - 0.01) * min(1, self._step / 20000):
|
||||||
|
action = self.env.rand_act()
|
||||||
obs, reward, done, info = self.env.step(action)
|
obs, reward, done, info = self.env.step(action)
|
||||||
self._tds.append(self.to_td(obs, action, reward))
|
self._tds.append(self.to_td(obs, action, reward))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user