Replace Tanh with ReLU of scores and simplify matrix operations

This commit is contained in:
TJU_Lu
2024-12-22 15:43:04 +08:00
parent d21a2f818b
commit 43cf7bea6d
3 changed files with 45 additions and 90 deletions

View File

@@ -51,10 +51,9 @@ class YopoPolicy(nn.Module):
output_dim = (self.action_dim + 1) * self.lattice_space.vel_num * self.lattice_space.radio_num
# input state dim = hidden_state + vel + acc + goal
input_dim = self.hidden_state + 9
self.image_backbone = YopoBackbone(self.hidden_state,
self.lattice_space.horizon_num * self.lattice_space.vertical_num)
self.image_backbone = YopoBackbone(self.hidden_state, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
self.state_backbone = nn.Sequential()
self.yopo_header = self.create_header(input_dim, output_dim, self.net_arch, self.activation_fn, True)
self.yopo_header = self.create_header(input_dim, output_dim, self.net_arch, self.activation_fn)
self.grad_layer = CostAndGradLayer.apply
# Setup optimizer with initial learning rate
learning_rate = lr_schedule(1) if lr_schedule is not None else 1e-3
@@ -63,26 +62,24 @@ class YopoPolicy(nn.Module):
# TenserRT Transfer
def forward(self, depth: th.Tensor, obs: th.Tensor) -> th.Tensor:
"""
forward propagation of neural network, only used for TensorRT conversion.
forward propagation of neural network, separated for TensorRT conversion.
"""
depth_feature = self.image_backbone(depth)
obs_feature = self.state_backbone(obs)
input_tensor = th.cat((obs_feature, depth_feature), 1)
output = self.yopo_header(input_tensor)
# [batch, endstate+score, lattice_row, lattice_col]
return output
endstate = th.tanh(output[:, :9])
score = th.relu(output[:, 9:])
return th.cat((endstate, score), dim=1) # [batch, endstate+score, lattice_row, lattice_col]
# Training Policy
def inference(self, depth: th.Tensor, obs: th.Tensor) -> th.Tensor:
"""
For network training:
(1) predicted the endstate(end_state) and score
(1) predicted the endstate and score
(2) record the gradients and costs of prediction
"""
depth_feature = self.image_backbone(depth)
obs_feature = self.state_backbone(obs)
input_tensor = th.cat((obs_feature, depth_feature), 1)
output = self.yopo_header(input_tensor)
output = self.forward(depth, obs)
# [batch, endstate+score, lattice_num]
batch_size = obs.shape[0]
@@ -93,7 +90,7 @@ class YopoPolicy(nn.Module):
endstate_score_predictions = th.zeros_like(output).to(self.device)
cost_labels = th.zeros((batch_size, self.lattice_space.horizon_num * self.lattice_space.vertical_num)).to(self.device)
for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num):
for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num):
id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
ids = id * np.ones((batch_size, 1))
endstate = self.pred_to_endstate(endstate_pred[:, :, i], id)
@@ -106,40 +103,30 @@ class YopoPolicy(nn.Module):
return endstate_score_predictions, cost_labels
# Testing Policy
def predict(self, depth: th.Tensor, obs: th.Tensor, return_all_preds=False) -> th.Tensor:
def predict(self, depth: th.Tensor, obs: th.Tensor) -> th.Tensor:
"""
For network testing:
(1) predicted the endstate(end_state) and score
(1) predicted the endstate and score, and return the optimal
"""
with th.no_grad():
depth_feature = self.image_backbone(depth)
obs_feature = self.state_backbone(obs.float())
input_tensor = th.cat((obs_feature, depth_feature), 1)
output = self.yopo_header(input_tensor)
output = self.forward(depth, obs.float())
batch_size = obs.shape[0]
output = output.view(batch_size, 10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
endstate_pred = output[:, 0:9, :]
score_pred = output[:, 9, :]
if not return_all_preds:
endstate_prediction = th.zeros(batch_size, self.action_dim)
score_prediction = th.zeros(batch_size, 1)
for i in range(0, batch_size):
action_id = th.argmin(score_pred[i]).item()
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction[i] = self.pred_to_endstate(th.unsqueeze(endstate_pred[i, :, action_id], 0), lattice_id)
score_prediction[i] = score_pred[i, action_id]
else:
endstate_prediction = th.zeros_like(endstate_pred)
score_prediction = score_pred
for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num):
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id)
endstate_prediction[:, :, i] = endstate
endstate_prediction = th.zeros(batch_size, self.action_dim)
score_prediction = th.zeros(batch_size, 1)
for i in range(batch_size):
action_id = th.argmin(score_pred[i]).item()
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction[i] = self.pred_to_endstate(th.unsqueeze(endstate_pred[i, :, action_id], 0), lattice_id)
score_prediction[i] = score_pred[i, action_id]
return endstate_prediction, score_prediction
def pred_to_endstate(self, endstate_pred: th.Tensor, id: int):
def pred_to_endstate(self, endstate_pred: th.Tensor, id: int) -> th.Tensor:
"""
Transform the predicted state to the body frame.
"""
@@ -154,9 +141,9 @@ class YopoPolicy(nn.Module):
endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max
endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max
Rbp = self.lattice_primitive.getRotation(id)
endstate_vb = th.matmul(th.tensor(Rbp).to(self.device), endstate_vp.t()).t()
endstate_ab = th.matmul(th.tensor(Rbp).to(self.device), endstate_ap.t()).t()
Rpb = th.tensor(self.lattice_primitive.getRotation(id).T).to(self.device)
endstate_vb = th.matmul(endstate_vp, Rpb)
endstate_ab = th.matmul(endstate_ap, Rpb)
endstate = th.cat((endstate_p, endstate_vb, endstate_ab), dim=1)
endstate[:, [0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]]
return endstate
@@ -170,20 +157,18 @@ class YopoPolicy(nn.Module):
) -> nn.Sequential:
if len(net_arch) > 0:
modules = [nn.Conv2d(in_channels=input_dim, out_channels=net_arch[0], kernel_size=1, stride=1, padding=0),
activation_fn()]
modules = [nn.Conv2d(in_channels=input_dim, out_channels=net_arch[0], kernel_size=1, stride=1, padding=0), activation_fn()]
else:
modules = []
for idx in range(len(net_arch) - 1):
modules.append(nn.Conv2d(in_channels=net_arch[idx], out_channels=net_arch[idx + 1], kernel_size=1, stride=1,
padding=0))
modules.append(nn.Conv2d(in_channels=net_arch[idx], out_channels=net_arch[idx + 1], kernel_size=1, stride=1, padding=0))
modules.append(activation_fn())
if output_dim > 0:
last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim
modules.append(nn.Conv2d(in_channels=last_layer_dim, out_channels=output_dim, kernel_size=1, stride=1,
padding=0))
modules.append(nn.Conv2d(in_channels=last_layer_dim, out_channels=output_dim, kernel_size=1, stride=1, padding=0))
if squash_output:
modules.append(nn.Tanh())
return nn.Sequential(*modules)