Simplify inference node, vectorize NumPy operations, and fix the TensorRT latency timing bug.

This commit is contained in:
TJU_Lu
2024-12-17 21:10:46 +08:00
parent 59364bef31
commit 35cd195a10
2 changed files with 77 additions and 96 deletions

View File

@@ -46,6 +46,7 @@ def parser():
parser.add_argument("--trial", type=int, default=1, help="trial number")
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
parser.add_argument("--iter", type=int, default=0, help="iter number")
parser.add_argument("--fp16_mode", type=int, default=1, help="fp16 or fp32")
parser.add_argument("--filename", type=str, default='yopo_trt.pth', help="output file name")
return parser
@@ -75,6 +76,7 @@ if __name__ == "__main__":
saved_variables = torch.load(weight, map_location=device)
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
model.policy.set_training_mode(False)
torch.set_grad_enabled(False)
lattice_space = saved_variables["data"]["lattice_space"]
lattice_primitive = saved_variables["data"]["lattice_primitive"]
@@ -86,7 +88,7 @@ if __name__ == "__main__":
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
depth_in = torch.from_numpy(depth).cuda()
obs_in = torch.from_numpy(obs_input).cuda()
model_trt = torch2trt(model.policy, [depth_in, obs_in])
model_trt = torch2trt(model.policy, [depth_in, obs_in], fp16_mode=args.fp16_mode)
torch.save(model_trt.state_dict(), args.filename)
print("TensorRT Transfer Finish!")
@@ -95,18 +97,27 @@ if __name__ == "__main__":
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
print("Evaluation...")
# warm up...
# Warm Up...
y_trt = model_trt(depth_in, obs_in)
y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
# PyTorch Latency
torch_start = time.time()
y = model.policy(depth_in, obs_in)
torch.cuda.synchronize()
torch_end = time.time()
# TensorRT Latency
trt_start = time.time()
y_trt = model_trt(depth_in, obs_in)
torch.cuda.synchronize()
trt_end = time.time()
# Transfer Error
error = torch.mean(torch.abs(y - y_trt))
print("Torch Latency: ", 1000 * (torch_end - torch_start),
"ms, TensorRT Latency: ", 1000 * (trt_end - torch_end),
"ms, Transfer Error: ", error.item())
print(f"Torch Latency: {1000 * (torch_end - torch_start):.3f} ms, "
f"TensorRT Latency: {1000 * (trt_end - trt_start):.3f} ms, "
f"Transfer Error: {error.item():.8f}")