Simplify the inference node, vectorize the NumPy operations, and fix a timing bug.

This commit is contained in:
TJU_Lu
2024-12-17 21:10:46 +08:00
parent 59364bef31
commit 35cd195a10
2 changed files with 77 additions and 96 deletions

View File

@@ -88,7 +88,7 @@ class YopoNet:
def callback_set_goal(self, data): def callback_set_goal(self, data):
self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2]) self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2])
print("New goal:", self.goal) print("New Goal:", self.goal)
# the first frame # the first frame
def callback_odometry(self, data): def callback_odometry(self, data):
@@ -104,19 +104,20 @@ class YopoNet:
self.new_odom = True self.new_odom = True
def process_odom(self): def process_odom(self):
# Rwb # Rwb -> Rwc -> Rcw
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y, Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x, self.odom.pose.pose.orientation.y,
self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix() self.odom.pose.pose.orientation.z, self.odom.pose.pose.orientation.w]).as_matrix()
self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc) self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc)
Rotation_cw = self.Rotation_wc.T
if self.odom_ref_init: if self.odom_ref_init:
odom_data = self.odom_ref odom_data = self.odom_ref
# vel_b # vel_b
vel_w = np.array([odom_data.twist.twist.linear.x, odom_data.twist.twist.linear.y, odom_data.twist.twist.linear.z]) vel_w = np.array([odom_data.twist.twist.linear.x, odom_data.twist.twist.linear.y, odom_data.twist.twist.linear.z])
vel_b = np.dot(np.linalg.inv(self.Rotation_wc), vel_w) vel_b = np.dot(Rotation_cw, vel_w)
# acc_b (acc stored in angular in our ref_state topic) # acc_b (acc stored in angular in our ref_state topic)
acc_w = np.array([odom_data.twist.twist.angular.x, odom_data.twist.twist.angular.y, odom_data.twist.twist.angular.z]) acc_w = np.array([odom_data.twist.twist.angular.x, odom_data.twist.twist.angular.y, odom_data.twist.twist.angular.z])
acc_b = np.dot(np.linalg.inv(self.Rotation_wc), acc_w) acc_b = np.dot(Rotation_cw, acc_w)
else: else:
odom_data = self.odom odom_data = self.odom
vel_b = np.array([0.0, 0.0, 0.0]) vel_b = np.array([0.0, 0.0, 0.0])
@@ -125,7 +126,7 @@ class YopoNet:
# pose and goal_dir # pose and goal_dir
pos = np.array([odom_data.pose.pose.position.x, odom_data.pose.pose.position.y, odom_data.pose.pose.position.z]) pos = np.array([odom_data.pose.pose.position.x, odom_data.pose.pose.position.y, odom_data.pose.pose.position.z])
goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos) goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos)
goal_b = np.dot(np.linalg.inv(self.Rotation_wc), goal_w) goal_b = np.dot(Rotation_cw, goal_w)
vel_acc = np.concatenate((vel_b, acc_b), axis=0) vel_acc = np.concatenate((vel_b, acc_b), axis=0)
vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :]) vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :])
@@ -154,8 +155,7 @@ class YopoNet:
if self.verbose: if self.verbose:
self.time_interpolation = self.time_interpolation + (time.time() - start) self.time_interpolation = self.time_interpolation + (time.time() - start)
self.count_interpolation = self.count_interpolation + 1 self.count_interpolation = self.count_interpolation + 1
print("Time Consuming: interpolation:", self.time_interpolation / self.count_interpolation) print(f"Time Consuming: depth-interpolation: {1000 * self.time_interpolation / self.count_interpolation:.2f}ms")
# cv2.imshow("1", depth_[0][0]) # cv2.imshow("1", depth_[0][0])
# cv2.waitKey(1) # cv2.waitKey(1)
self.depth = depth_.astype(np.float32) self.depth = depth_.astype(np.float32)
@@ -164,8 +164,7 @@ class YopoNet:
# TODO: Move the test_policy to callback_depth directly? # TODO: Move the test_policy to callback_depth directly?
def test_policy(self, _timer): def test_policy(self, _timer):
if self.new_depth and self.new_odom: if self.new_depth and self.new_odom:
self.new_odom = False self.new_odom, self.new_depth = False, False
self.new_depth = False
obs = self.process_odom() obs = self.process_odom()
odom_sec = self.odom.header.stamp.to_sec() odom_sec = self.odom.header.stamp.to_sec()
@@ -176,49 +175,29 @@ class YopoNet:
obs_norm_input = obs_norm_input.to(self.device, non_blocking=True) obs_norm_input = obs_norm_input.to(self.device, non_blocking=True)
# torch.cuda.synchronize() # torch.cuda.synchronize()
# forward
if self.use_trt: # TensorRT (inference speed increased by 10x)
time1 = time.time() time1 = time.time()
trt_output = self.policy(depth, obs_norm_input) # Forward (TensorRT: inference speed increased by 5x)
with torch.no_grad():
network_output = self.policy(depth, obs_norm_input)
network_output = network_output.cpu().numpy() # torch.cuda.synchronize() is not needed here
time2 = time.time() time2 = time.time()
endstate_pred, score_pred = self.trt_process(trt_output, return_all_preds=self.visualize) # Replacing PyTorch operation on CUDA with NumPy operation on CPU (speed increased by 10x)
endstate_pred = endstate_pred.squeeze() endstate_pred, score_pred = self.process_output(network_output, return_all_preds=self.visualize)
time3 = time.time() time3 = time.time()
else:
time1 = time.time()
endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input, return_all_preds=self.visualize)
endstate_pred = endstate_pred.cpu().numpy().squeeze()
score_pred = score_pred.cpu().numpy()
time2 = time3 = time.time()
# Transform the prediction(body frame) to the world frame with the attitude in inference # Vectorization: transform the prediction(P V A in body frame) to the world frame with the attitude (without the position)
# Replacing PyTorch calculations on CUDA with NumPy calculations on the CPU (speed increased by 10x) endstate_c = endstate_pred.T.reshape(-1, 3, 3)
endstate_b = endstate_pred endstate_w = np.matmul(self.Rotation_wc, endstate_c)
endstate_w = np.zeros_like(endstate_b) endstate_w = endstate_w.reshape(-1, 9).T
traj_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num if self.visualize else 1
Pb, Vb, Ab = [np.zeros((3, traj_num)) for _ in range(3)]
for i in range(3):
Pb[i] = endstate_b[3 * i]
Vb[i] = endstate_b[3 * i + 1]
Ab[i] = endstate_b[3 * i + 2]
# pos_actual = np.array([self.odom.pose.pose.position.x,
# self.odom.pose.pose.position.y,
# self.odom.pose.pose.position.z])
Pw = np.dot(self.Rotation_wc, Pb) # + np.tile(pos_actual, (15, 1)).T
Vw = np.dot(self.Rotation_wc, Vb)
Aw = np.dot(self.Rotation_wc, Ab)
for i in range(3):
endstate_w[3 * i] = Pw[i]
endstate_w[3 * i + 1] = Vw[i]
endstate_w[3 * i + 2] = Aw[i]
if self.verbose: if self.verbose:
self.time_prepare = self.time_prepare + (time1 - time0) self.time_prepare = self.time_prepare + (time1 - time0)
self.time_forward = self.time_forward + (time2 - time1) self.time_forward = self.time_forward + (time2 - time1)
self.time_process = self.time_process + (time3 - time2) self.time_process = self.time_process + (time3 - time2)
self.count = self.count + 1 self.count = self.count + 1
print("Time Consuming: prepare:", self.time_prepare / self.count, "; forward:", self.time_forward / self.count, print(f"Time Consuming: data-prepare: {1000 * self.time_prepare / self.count:.2f}ms; "
"; process:", self.time_process / self.count) f"network-inference: {1000 * self.time_forward / self.count:.2f}ms; "
f"post-process: {1000 * self.time_process / self.count:.2f}ms")
# publish # publish
if not self.visualize: if not self.visualize:
@@ -232,7 +211,7 @@ class YopoNet:
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms) endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
self.endstate_pub.publish(endstate_pred_to_pub) self.endstate_pub.publish(endstate_pred_to_pub)
# visualization # visualization
endstate_score_preds = np.concatenate((endstate_w, score_pred), axis=0) endstate_score_preds = np.vstack([endstate_w, score_pred])
all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1)) all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1))
all_endstate_pred.layout.dim.append(MultiArrayDimension()) all_endstate_pred.layout.dim.append(MultiArrayDimension())
all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1] all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1]
@@ -246,28 +225,25 @@ class YopoNet:
elif not self.new_odom: elif not self.new_odom:
self.odom_ref_init = False self.odom_ref_init = False
def trt_process(self, input_tensor: torch.Tensor, return_all_preds=False) -> torch.Tensor: def process_output(self, network_output, return_all_preds=False):
batch_size = input_tensor.shape[0] if network_output.shape[0] != 1:
input_tensor = input_tensor.cpu().numpy() raise ValueError("batch of output values must be 1 in test!")
input_tensor = input_tensor.reshape(batch_size, 10, self.lattice_space.horizon_num * self.lattice_space.vertical_num) network_output = network_output.reshape(10, self.lattice_space.horizon_num * self.lattice_space.vertical_num)
endstate_pred = input_tensor[:, 0:9, :] endstate_pred = network_output[0:9, :]
score_pred = input_tensor[:, 9, :] score_pred = network_output[9, :]
if not return_all_preds: if not return_all_preds:
endstate_prediction = np.zeros((batch_size, 9)) action_id = np.argmin(score_pred)
score_prediction = np.zeros((batch_size, 1))
for i in range(0, batch_size):
action_id = np.argmin(score_pred[i])
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
endstate_prediction[i] = self.pred_to_endstate(np.expand_dims(endstate_pred[i, :, action_id], axis=0), lattice_id) endstate_prediction = self.pred_to_endstate(endstate_pred[:, action_id], lattice_id)
score_prediction[i] = score_pred[i, action_id] endstate_prediction = endstate_prediction[:, np.newaxis]
score_prediction = score_pred[action_id]
else: else:
endstate_prediction = np.zeros_like(endstate_pred) endstate_prediction = np.zeros_like(endstate_pred)
score_prediction = score_pred score_prediction = score_pred
for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num): for i in range(self.lattice_space.horizon_num * self.lattice_space.vertical_num):
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id) endstate_prediction[:, i] = self.pred_to_endstate(endstate_pred[:, i], lattice_id)
endstate_prediction[:, :, i] = endstate
return endstate_prediction, score_prediction return endstate_prediction, score_prediction
@@ -276,45 +252,40 @@ class YopoNet:
convert the observation from body frame to primitive frame, convert the observation from body frame to primitive frame,
and then concatenate it with the depth features (to ensure the translational invariance) and then concatenate it with the depth features (to ensure the translational invariance)
""" """
obs_return = np.ones((obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]), dtype=np.float32) if obs.shape[0] != 1:
raise ValueError("batch of input observations must be 1 in test!")
obs_return = np.ones((obs.shape[0], obs.shape[1], self.lattice_space.vertical_num, self.lattice_space.horizon_num), dtype=np.float32)
id = 0 id = 0
v_b = obs[:, 0:3] obs_reshaped = obs.reshape(3, 3)
a_b = obs[:, 3:6]
g_b = obs[:, 6:9]
for i in range(self.lattice_space.vertical_num - 1, -1, -1): for i in range(self.lattice_space.vertical_num - 1, -1, -1):
for j in range(self.lattice_space.horizon_num - 1, -1, -1): for j in range(self.lattice_space.horizon_num - 1, -1, -1):
Rbp = self.lattice_primitive.getRotation(id) Rbp = self.lattice_primitive.getRotation(id)
v_p = np.dot(Rbp.T, v_b.T).T obs_return_reshaped = np.dot(obs_reshaped, Rbp)
a_p = np.dot(Rbp.T, a_b.T).T obs_return[:, :, i, j] = obs_return_reshaped.reshape(9)
g_p = np.dot(Rbp.T, g_b.T).T
obs_return[:, i, j, 0:3] = v_p
obs_return[:, i, j, 3:6] = a_p
obs_return[:, i, j, 6:9] = g_p
# obs_return[:, i, j, 0:6] = self.normalize_obs(obs_return[:, i, j, 0:6])
id = id + 1 id = id + 1
obs_return = np.transpose(obs_return, [0, 3, 1, 2])
return torch.from_numpy(obs_return) return torch.from_numpy(obs_return)
def pred_to_endstate(self, endstate_pred: np.ndarray, id: int): def pred_to_endstate(self, endstate_pred: np.ndarray, id: int):
""" """
Transform the predicted state to the body frame. Transform the predicted state to the body frame.
""" """
delta_yaw = endstate_pred[:, 0] * self.lattice_primitive.yaw_diff delta_yaw = endstate_pred[0] * self.lattice_primitive.yaw_diff
delta_pitch = endstate_pred[:, 1] * self.lattice_primitive.pitch_diff delta_pitch = endstate_pred[1] * self.lattice_primitive.pitch_diff
radio = endstate_pred[:, 2] * self.lattice_space.radio_range + self.lattice_space.radio_range radio = endstate_pred[2] * self.lattice_space.radio_range + self.lattice_space.radio_range
yaw, pitch = self.lattice_primitive.getAngleLattice(id) yaw, pitch = self.lattice_primitive.getAngleLattice(id)
endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio
endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio
endstate_z = np.sin(pitch + delta_pitch) * radio endstate_z = np.sin(pitch + delta_pitch) * radio
endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1) endstate_p = np.array((endstate_x, endstate_y, endstate_z))
endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max endstate_vp = endstate_pred[3:6] * self.lattice_space.vel_max
endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max endstate_ap = endstate_pred[6:9] * self.lattice_space.acc_max
Rbp = self.lattice_primitive.getRotation(id) Rpb = self.lattice_primitive.getRotation(id).T
endstate_vb = np.matmul(Rbp, endstate_vp.T).T endstate_vb = np.matmul(endstate_vp, Rpb)
endstate_ab = np.matmul(Rbp, endstate_ap.T).T endstate_ab = np.matmul(endstate_ap, Rpb)
endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1) endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab))
endstate[:, [0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]] endstate[[0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[[0, 3, 6, 1, 4, 7, 2, 5, 8]]
return endstate return endstate
def normalize_obs(self, vel_acc): def normalize_obs(self, vel_acc):
@@ -326,11 +297,8 @@ class YopoNet:
depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32) depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32)
obs = np.zeros(shape=[1, 9], dtype=np.float32) obs = np.zeros(shape=[1, 9], dtype=np.float32)
obs_input = self.prepare_input_observation(obs) obs_input = self.prepare_input_observation(obs)
if self.use_trt: network_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
trt_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device)) self.process_output(network_output.cpu().numpy(), return_all_preds=True)
self.trt_process(trt_output, return_all_preds=True)
else:
self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device), return_all_preds=True)
def parser(): def parser():
@@ -339,6 +307,7 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number") parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--trt_file", type=str, default='yopo_trt.pth', help="tensorrt filename")
return parser return parser
@@ -348,9 +317,10 @@ def main():
args = parser().parse_args() args = parser().parse_args()
rsg_root = os.path.dirname(os.path.abspath(__file__)) rsg_root = os.path.dirname(os.path.abspath(__file__))
if args.use_tensorrt: if args.use_tensorrt:
weight = "yopo_trt.pth" weight = args.trt_file
else: else:
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter) weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
print("load weight from:", weight)
settings = {'use_tensorrt': args.use_tensorrt, settings = {'use_tensorrt': args.use_tensorrt,
'network_frequency': 30, 'network_frequency': 30,

View File

@@ -46,6 +46,7 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number") parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number") parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number") parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--fp16_mode", type=int, default=1, help="fp16 or fp32")
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name") parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
return parser return parser
@@ -75,6 +76,7 @@ if __name__ == "__main__":
saved_variables = torch.load(weight, map_location=device) saved_variables = torch.load(weight, map_location=device)
model.policy.load_state_dict(saved_variables["state_dict"], strict=False) model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
model.policy.set_training_mode(False) model.policy.set_training_mode(False)
torch.set_grad_enabled(False)
lattice_space = saved_variables["data"]["lattice_space"] lattice_space = saved_variables["data"]["lattice_space"]
lattice_primitive = saved_variables["data"]["lattice_primitive"] lattice_primitive = saved_variables["data"]["lattice_primitive"]
@@ -86,7 +88,7 @@ if __name__ == "__main__":
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive) obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
depth_in = torch.from_numpy(depth).cuda() depth_in = torch.from_numpy(depth).cuda()
obs_in = torch.from_numpy(obs_input).cuda() obs_in = torch.from_numpy(obs_input).cuda()
model_trt = torch2trt(model.policy, [depth_in, obs_in]) model_trt = torch2trt(model.policy, [depth_in, obs_in], fp16_mode=args.fp16_mode)
torch.save(model_trt.state_dict(), args.filename) torch.save(model_trt.state_dict(), args.filename)
print("TensorRT Transfer Finish!") print("TensorRT Transfer Finish!")
@@ -95,18 +97,27 @@ if __name__ == "__main__":
# model_trt.load_state_dict(torch.load('yopo_trt.pth')) # model_trt.load_state_dict(torch.load('yopo_trt.pth'))
print("Evaluation...") print("Evaluation...")
# warm up... # Warm Up...
y_trt = model_trt(depth_in, obs_in) y_trt = model_trt(depth_in, obs_in)
y = model.policy(depth_in, obs_in) y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
# PyTorch Latency
torch_start = time.time() torch_start = time.time()
y = model.policy(depth_in, obs_in) y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
torch_end = time.time() torch_end = time.time()
# TensorRT Latency
trt_start = time.time()
y_trt = model_trt(depth_in, obs_in) y_trt = model_trt(depth_in, obs_in)
torch.cuda.synchronize()
trt_end = time.time() trt_end = time.time()
# Transfer Error
error = torch.mean(torch.abs(y - y_trt)) error = torch.mean(torch.abs(y - y_trt))
print("Torch Latency: ", 1000 * (torch_end - torch_start),
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end), print(f"Torch Latency: {1000 * (torch_end - torch_start):.3f} ms, "
"ms, Transfer Error: ", error.item()) f"TensorRT Latency: {1000 * (trt_end - trt_start):.3f} ms, "
f"Transfer Error: {error.item():.8f}")