Initial Commit (tested training, testing, and TRT conversion)
This commit is contained in:
101
run/yopo_trt_transfer.py
Normal file
101
run/yopo_trt_transfer.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
将yopo模型转换为Tensorrt
|
||||
prepare:
|
||||
1 pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com
|
||||
2 git clone https://github.com/NVIDIA-AI-IOT/torch2trt
|
||||
cd torch2trt
|
||||
python setup.py install
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch2trt import torch2trt
|
||||
from flightgym import QuadrotorEnv_v1
|
||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||
from flightpolicy.envs import vec_env_wrapper as wrapper
|
||||
from flightpolicy.yopo.yopo_algorithm import YopoAlgorithm
|
||||
|
||||
|
||||
def prapare_input_observation(obs, lattice_space, lattice_primitive):
|
||||
obs_return = np.ones(
|
||||
(obs.shape[0], lattice_space.vertical_num, lattice_space.horizon_num, obs.shape[1]),
|
||||
dtype=np.float32)
|
||||
id = 0
|
||||
v_b = obs[:, 0:3]
|
||||
a_b = obs[:, 3:6]
|
||||
g_b = obs[:, 6:9]
|
||||
for i in range(lattice_space.vertical_num - 1, -1, -1):
|
||||
for j in range(lattice_space.horizon_num - 1, -1, -1):
|
||||
Rbp = lattice_primitive.getRotation(id)
|
||||
v_p = np.dot(Rbp.T, v_b.T).T
|
||||
a_p = np.dot(Rbp.T, a_b.T).T
|
||||
g_p = np.dot(Rbp.T, g_b.T).T
|
||||
obs_return[:, i, j, 0:3] = v_p
|
||||
obs_return[:, i, j, 3:6] = a_p
|
||||
obs_return[:, i, j, 6:9] = g_p
|
||||
id = id + 1
|
||||
obs_return = np.transpose(obs_return, [0, 3, 1, 2])
|
||||
return obs_return
|
||||
|
||||
|
||||
def parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = parser().parse_args()
|
||||
# load configurations
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||
cfg["env"]["num_envs"] = 1
|
||||
cfg["env"]["supervised"] = False
|
||||
cfg["env"]["imitation"] = False
|
||||
cfg["env"]["render"] = False
|
||||
|
||||
# create environment
|
||||
train_env = QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False)
|
||||
train_env = wrapper.FlightEnvVec(train_env)
|
||||
model = YopoAlgorithm(env=train_env,
|
||||
policy_kwargs=dict(
|
||||
activation_fn=torch.nn.ReLU,
|
||||
net_arch=[256, 256],
|
||||
hidden_state=64
|
||||
))
|
||||
|
||||
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
||||
device = torch.device("cuda")
|
||||
saved_variables = torch.load(weight, map_location=device)
|
||||
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
|
||||
model.policy.set_training_mode(False)
|
||||
|
||||
lattice_space = saved_variables["data"]["lattice_space"]
|
||||
lattice_primitive = saved_variables["data"]["lattice_primitive"]
|
||||
|
||||
# The inputs should be consistent with training
|
||||
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
|
||||
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
||||
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
|
||||
depth_in = torch.from_numpy(depth).cuda()
|
||||
obs_in = torch.from_numpy(obs_input).cuda()
|
||||
model_trt = torch2trt(model.policy, [depth_in, obs_in])
|
||||
torch.save(model_trt.state_dict(), args.dir)
|
||||
|
||||
# from torch2trt import TRTModule
|
||||
# model_trt = TRTModule()
|
||||
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
|
||||
|
||||
y_trt = model_trt(depth_in, obs_in)
|
||||
y = model.policy(depth_in, obs_in)
|
||||
error = torch.mean(torch.abs(y - y_trt))
|
||||
print("transfer error: ", error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user