Initial Commit (tested training, testing, and TRT conversion)
This commit is contained in:
119
run/data_collection_realworld.py
Normal file
119
run/data_collection_realworld.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
# 收集实飞数据,记录位置、姿态、图像,用于离线fine-tuning (保存至save_dir)
|
||||
# 注意: 由于里程计漂移,可能utils/pointcloud_clip需要对地图进行微调,需对无人机位置和yaw, pitch, roll做相同的变换
|
||||
# 注意保证地图和里程计处于同一坐标系,同时录包+保存地图
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time, os, sys
|
||||
from cv_bridge import CvBridge, CvBridgeError
|
||||
import rospy
|
||||
from sensor_msgs.msg import Image
|
||||
from nav_msgs.msg import Odometry
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
depth_img = np.zeros([270, 480])
|
||||
pos = np.array([0, 0, 0])
|
||||
quat = np.array([1, 0, 0, 0])
|
||||
positions = []
|
||||
quaternions = []
|
||||
frame_id = 0
|
||||
new_depth = False
|
||||
new_odom = False
|
||||
first_frame = True
|
||||
last_time = time.time()
|
||||
save_dir = os.environ["FLIGHTMARE_PATH"] + "/run/depth_realworld"
|
||||
label_path = save_dir + "/label.npz"
|
||||
if not os.path.exists(save_dir):
|
||||
os.mkdir(save_dir)
|
||||
# Due to odometry drift, the map is adjusted, and the drone's position is also adjusted accordingly.
|
||||
R_no = Rotation.from_euler('ZYX', [15, 3, 0.0], degrees=True) # yaw, pitch, roll
|
||||
translation_no = np.array([0, 0, 2])
|
||||
|
||||
|
||||
def callback_odometry(data):
|
||||
# NWU
|
||||
global pos, quat, new_odom, R_no, translation_no
|
||||
p_ob = np.array([[data.pose.pose.position.x],
|
||||
[data.pose.pose.position.y],
|
||||
[data.pose.pose.position.z]])
|
||||
q_ob = np.array([data.pose.pose.orientation.x,
|
||||
data.pose.pose.orientation.y,
|
||||
data.pose.pose.orientation.z,
|
||||
data.pose.pose.orientation.w])
|
||||
R_ob = Rotation.from_quat(q_ob) # old->body (xyzw)
|
||||
quat_xyzw = (R_no * R_ob).as_quat() # new->body (xyzw)
|
||||
quat = np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])
|
||||
pos = np.squeeze(np.dot(R_no.as_matrix(), p_ob)) + translation_no
|
||||
new_odom = True
|
||||
|
||||
|
||||
def callback_depth(data):
|
||||
global depth_img, new_depth
|
||||
max_dis = 20.0
|
||||
min_dis = 0.03
|
||||
height = 270
|
||||
width = 480
|
||||
scale = 0.001
|
||||
bridge = CvBridge()
|
||||
try:
|
||||
depth_ = bridge.imgmsg_to_cv2(data, "32FC1")
|
||||
except:
|
||||
print("CV_bridge ERROR: Your ros and python path has something wrong!")
|
||||
|
||||
if depth_.shape[0] != height or depth_.shape[1] != width:
|
||||
depth_ = cv2.resize(depth_, (width, height), interpolation=cv2.INTER_NEAREST)
|
||||
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
|
||||
|
||||
try:
|
||||
nan_mask = np.isnan(depth_) | (depth_ < min_dis)
|
||||
depth_ = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 3, cv2.INPAINT_NS)
|
||||
depth_ = depth_.astype(np.float32) / 255.0
|
||||
except:
|
||||
print("Interpolation failed")
|
||||
|
||||
# Not necessary, but encountered some inexplicable errors previously, so temporarily kept.
|
||||
if np.sum(np.isnan(depth_)) > 0:
|
||||
depth_[np.isnan(depth_)] = 0
|
||||
print("WARN: Have NAN values in depth image")
|
||||
|
||||
depth_img = depth_.copy()
|
||||
new_depth = True
|
||||
|
||||
|
||||
def save_data(_timer):
|
||||
global pos, quat, new_odom, depth_img, new_depth, last_time, first_frame
|
||||
global save_dir, label_path, frame_id, positions, quaternions
|
||||
if not (new_odom and new_depth):
|
||||
if not first_frame and time.time() - last_time > 1:
|
||||
np.savez(
|
||||
label_path,
|
||||
positions=np.asarray(positions),
|
||||
quaternions=np.asarray(quaternions),
|
||||
)
|
||||
print("Record Done!")
|
||||
sys.exit()
|
||||
return
|
||||
new_odom, new_depth = False, False
|
||||
|
||||
image_path = save_dir + "/img_" + str(frame_id) + ".tif"
|
||||
cv2.imwrite(image_path, depth_img)
|
||||
positions.append(pos)
|
||||
quaternions.append(quat)
|
||||
|
||||
last_time = time.time()
|
||||
first_frame = False
|
||||
frame_id = frame_id + 1
|
||||
|
||||
|
||||
def main():
|
||||
rospy.init_node('data_collect', anonymous=False)
|
||||
odom_ref_sub = rospy.Subscriber("/odometry/imu", Odometry, callback_odometry, queue_size=1)
|
||||
depth_sub = rospy.Subscriber("/camera/depth/image_rect_raw", Image, callback_depth, queue_size=1)
|
||||
timer = rospy.Timer(rospy.Duration(0.033), save_data)
|
||||
print("Data Collection Node Ready!")
|
||||
rospy.spin()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
84
run/data_collection_simulation.py
Normal file
84
run/data_collection_simulation.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from flightgym import QuadrotorEnv_v1
|
||||
from flightpolicy.envs import vec_env_wrapper as wrapper
|
||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||
|
||||
|
||||
def configure_random_seed(seed, env=None):
|
||||
if env is not None:
|
||||
env.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
|
||||
def parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
parser.add_argument("--num_each_env", type=int, default=10000, help="num of images to save in each env")
|
||||
parser.add_argument("--num_env", type=int, default=10, help="num of env to change")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = parser().parse_args()
|
||||
|
||||
configure_random_seed(args.seed)
|
||||
|
||||
# load configurations
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||
cfg["env"]["num_envs"] = 1
|
||||
cfg["env"]["num_threads"] = 1
|
||||
cfg["env"]["render"] = True
|
||||
cfg["env"]["supervised"] = False
|
||||
cfg["env"]["imitation"] = False
|
||||
|
||||
os.system(os.environ["FLIGHTMARE_PATH"] + "/flightrender/RPG_Flightmare/flightmare.x86_64 &")
|
||||
env = QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False)
|
||||
env = wrapper.FlightEnvVec(env)
|
||||
env.connectUnity()
|
||||
|
||||
iteration = args.num_each_env
|
||||
epoch = args.num_env
|
||||
|
||||
home_dir = os.environ["FLIGHTMARE_PATH"] + cfg["env"]["dataset_path"]
|
||||
if not os.path.exists(home_dir):
|
||||
os.mkdir(home_dir)
|
||||
|
||||
for epoch_i in range(epoch):
|
||||
spacing = cfg["unity"]["avg_tree_spacing"]
|
||||
env.spawnTreesAndSavePointcloud(epoch_i, spacing)
|
||||
env.setMapID(np.array([-1]))
|
||||
env.reset(random=True)
|
||||
|
||||
positions = np.zeros([iteration, 3], dtype=np.float32)
|
||||
quaternions = np.zeros([iteration, 4], dtype=np.float32)
|
||||
|
||||
save_dir = os.environ["FLIGHTMARE_PATH"] + cfg["env"]["dataset_path"] + str(epoch_i) + "/"
|
||||
label_path = save_dir + "/label.npz"
|
||||
if not os.path.exists(save_dir):
|
||||
os.mkdir(save_dir)
|
||||
|
||||
for frame_id in tqdm(range(iteration)):
|
||||
image_path = save_dir + "/img_" + str(frame_id) + ".tif"
|
||||
observation = env.reset()
|
||||
positions[frame_id, :] = observation[0, 0:3]
|
||||
quaternions[frame_id, :] = observation[0, 9:]
|
||||
depth = env.getDepthImage(resize=False)
|
||||
cv2.imwrite(image_path, depth[0][0])
|
||||
|
||||
np.savez(
|
||||
label_path,
|
||||
positions=positions,
|
||||
quaternions=quaternions,
|
||||
)
|
||||
|
||||
env.disconnectUnity()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
118
run/run_yopo.py
Normal file
118
run/run_yopo.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from flightgym import QuadrotorEnv_v1
|
||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||
from flightpolicy.envs import vec_env_wrapper as wrapper
|
||||
from flightpolicy.yopo.yopo_algorithm import YopoAlgorithm
|
||||
|
||||
|
||||
def configure_random_seed(seed, env=None):
|
||||
if env is not None:
|
||||
env.seed(seed)
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# some cudnn methods can be random even after fixing the seed unless you tell it to be deterministic
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
|
||||
def parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
||||
parser.add_argument("--train", type=int, default=1, help="train or evaluate the policy?")
|
||||
parser.add_argument("--render", type=int, default=0, help="render with Unity?")
|
||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||
parser.add_argument("--pretrained", type=int, default=0, help="use pre-trained model?")
|
||||
parser.add_argument("--supervised", type=int, default=1, help="supervised learning?")
|
||||
parser.add_argument("--imitation", type=int, default=0, help="imitation learning?")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = parser().parse_args()
|
||||
|
||||
# load configurations
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||
cfg["env"]["supervised"] = bool(args.supervised)
|
||||
cfg["env"]["imitation"] = bool(args.imitation)
|
||||
if not args.train:
|
||||
cfg["env"]["num_envs"] = 1
|
||||
cfg["env"]["render"] = bool(args.render)
|
||||
if args.render:
|
||||
cfg["env"]["ply_path"] = "/flightrender/RPG_Flightmare/pointcloud_data/" # change the paths during test or imitation
|
||||
if not os.path.exists(os.environ["FLIGHTMARE_PATH"] + cfg["env"]["ply_path"]):
|
||||
os.mkdir(os.environ["FLIGHTMARE_PATH"] + cfg["env"]["ply_path"])
|
||||
os.system(os.environ["FLIGHTMARE_PATH"] + "/flightrender/RPG_Flightmare/flightmare.x86_64 &")
|
||||
|
||||
# create training environment
|
||||
train_env = QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False)
|
||||
train_env = wrapper.FlightEnvVec(train_env)
|
||||
|
||||
# set random seed
|
||||
configure_random_seed(args.seed, env=train_env)
|
||||
|
||||
# save the configuration and other files
|
||||
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
||||
log_dir = rsg_root + "/saved"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
model = YopoAlgorithm(
|
||||
tensorboard_log=log_dir,
|
||||
env=train_env,
|
||||
is_imitation=args.imitation,
|
||||
learning_starts=10000, # How many samples are collected before starting imitation learning
|
||||
train_freq=200, # How many steps of data to collect from each environment per round
|
||||
gradient_steps=200, # How many steps to train per round
|
||||
change_env_freq=20, # How many rounds of "collect-train" to reset the tree (-1: not reset)
|
||||
learning_rate=1.5e-4, # Learning rate
|
||||
batch_size=cfg["env"]["num_envs"], # Equal to the number of environment, as gradients are from environments
|
||||
buffer_size=100000, # Buffer size
|
||||
loss_weight=[1.0, 10.0], # Weights for the costs of endstate and score
|
||||
unselect=0, # Proportion of trajectories not optimized in each sample
|
||||
policy_kwargs=dict(
|
||||
activation_fn=torch.nn.ReLU,
|
||||
net_arch=[256, 256],
|
||||
hidden_state=64
|
||||
),
|
||||
verbose=1,
|
||||
)
|
||||
|
||||
if args.render:
|
||||
train_env.connectUnity()
|
||||
spacing = cfg["unity"]["avg_tree_spacing"]
|
||||
train_env.spawnTreesAndSavePointcloud(0, spacing)
|
||||
train_env.setMapID(-np.ones((train_env.num_envs, 1)))
|
||||
train_env.reset(random=True)
|
||||
|
||||
if args.train:
|
||||
if args.pretrained:
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
||||
saved_variables = torch.load(weight, map_location=device)
|
||||
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
|
||||
print("use pretrained model ", weight)
|
||||
|
||||
if args.supervised:
|
||||
model.supervised_learning(epoch=int(50), log_interval=(100, 50000)) # How many batches to print and save
|
||||
|
||||
elif args.imitation:
|
||||
model.imitation_learning(total_timesteps=int(1 * 1e6), log_interval=(1, 40))
|
||||
|
||||
else:
|
||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
saved_variables = torch.load(weight, map_location=device)
|
||||
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
|
||||
model.test_policy(num_rollouts=20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
run/saved/YOPO_1/Policy/epoch0_iter0.pth
Normal file
BIN
run/saved/YOPO_1/Policy/epoch0_iter0.pth
Normal file
Binary file not shown.
BIN
run/saved/YOPO_1/events.out.tfevents.1729323375.610.68573.0
Normal file
BIN
run/saved/YOPO_1/events.out.tfevents.1729323375.610.68573.0
Normal file
Binary file not shown.
388
run/test_yopo_ros.py
Normal file
388
run/test_yopo_ros.py
Normal file
@@ -0,0 +1,388 @@
|
||||
import rospy
|
||||
from sensor_msgs.msg import Image
|
||||
from nav_msgs.msg import Odometry
|
||||
from std_msgs.msg import Float32MultiArray, MultiArrayDimension
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from cv_bridge import CvBridge
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import time
|
||||
from ruamel.yaml import YAML
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
from flightpolicy.yopo.yopo_policy import YopoPolicy
|
||||
from flightpolicy.yopo.primitive_utils import LatticeParam, LatticePrimitive
|
||||
|
||||
try:
|
||||
from torch2trt import TRTModule
|
||||
except ImportError:
|
||||
print("tensorrt not found.")
|
||||
|
||||
|
||||
class YopoNet:
|
||||
def __init__(self, config, weight):
|
||||
self.config = config
|
||||
rospy.init_node('yopo_net', anonymous=False)
|
||||
# load params
|
||||
self.bridge = CvBridge()
|
||||
self.odom = Odometry()
|
||||
self.odom_ref = Odometry()
|
||||
self.height = self.config['img_height']
|
||||
self.width = self.config['img_width']
|
||||
self.depth = np.zeros((1, 1, self.config['img_height'], self.config['img_width']))
|
||||
self.goal = np.array(self.config['goal'])
|
||||
self.env = self.config['env']
|
||||
self.use_trt = self.config['use_tensorrt']
|
||||
self.verbose = self.config['verbose']
|
||||
self.visualize = self.config['visualize']
|
||||
self.Rotation_bc = R.from_euler('ZYX', [0, self.config['pitch_angle_deg'], 0], degrees=True).as_matrix()
|
||||
self.new_odom = False
|
||||
self.new_depth = False
|
||||
self.odom_ref_init = False
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/traj_opt.yaml", 'r'))
|
||||
self.lattice_space = LatticeParam(cfg)
|
||||
self.lattice_primitive = LatticePrimitive(self.lattice_space)
|
||||
|
||||
# eval
|
||||
self.time_forward = 0.0
|
||||
self.time_process = 0.0
|
||||
self.time_prepare = 0.0
|
||||
self.time_interpolation = 0.0
|
||||
self.count = 0
|
||||
self.count_interpolation = 0
|
||||
|
||||
# Load Network
|
||||
if self.use_trt:
|
||||
self.policy = TRTModule()
|
||||
self.policy.load_state_dict(torch.load(weight))
|
||||
else:
|
||||
saved_variables = torch.load(weight, map_location=self.device)
|
||||
saved_variables["data"]["lattice_space"] = self.lattice_space
|
||||
saved_variables["data"]["lattice_primitive"] = self.lattice_primitive
|
||||
self.policy = YopoPolicy(device=self.device, **saved_variables["data"])
|
||||
self.policy.load_state_dict(saved_variables["state_dict"], strict=False)
|
||||
self.policy.to(self.device)
|
||||
self.policy.set_training_mode(False)
|
||||
torch.set_grad_enabled(False)
|
||||
self.warm_up()
|
||||
|
||||
# ros publisher
|
||||
odom_topic = self.config['odom_topic']
|
||||
depth_topic = self.config['depth_topic']
|
||||
self.endstate_pub = rospy.Publisher("/yopo_net/pred_endstate", Float32MultiArray, queue_size=1)
|
||||
self.all_endstate_pub = rospy.Publisher("/yopo_net/pred_endstates", Float32MultiArray, queue_size=1)
|
||||
self.goal_pub = rospy.Publisher("/yopo_net/goal", Float32MultiArray, queue_size=1)
|
||||
# ros subscriber
|
||||
self.odom_sub = rospy.Subscriber(odom_topic, Odometry, self.callback_odometry, queue_size=1, tcp_nodelay=True)
|
||||
self.odom_ref_sub = rospy.Subscriber("/juliett/state_ref/odom", Odometry, self.callback_odometry_ref,
|
||||
queue_size=1, tcp_nodelay=True)
|
||||
self.depth_sub = rospy.Subscriber(depth_topic, Image, self.callback_depth, queue_size=1, tcp_nodelay=True)
|
||||
self.goal_sub = rospy.Subscriber("/move_base_simple/goal", PoseStamped, self.callback_set_goal, queue_size=1)
|
||||
self.timer_net = rospy.Timer(rospy.Duration(1. / self.config['network_frequency']), self.test_policy)
|
||||
print("YOPO Net Node Ready!")
|
||||
rospy.spin()
|
||||
|
||||
# the first frame
|
||||
def callback_odometry(self, data):
|
||||
self.odom = data
|
||||
if not self.odom_ref_init:
|
||||
self.new_odom = True
|
||||
|
||||
# the following frame (The planner is planning from the desired state, instead of the actual state)
|
||||
def callback_odometry_ref(self, data):
|
||||
if not self.odom_ref_init:
|
||||
print("odom ref init")
|
||||
self.odom_ref_init = True
|
||||
self.odom_ref = data
|
||||
self.new_odom = True
|
||||
|
||||
def process_odom(self):
|
||||
# Rwb
|
||||
Rotation_wb = R.from_quat([self.odom.pose.pose.orientation.x,
|
||||
self.odom.pose.pose.orientation.y,
|
||||
self.odom.pose.pose.orientation.z,
|
||||
self.odom.pose.pose.orientation.w]).as_matrix()
|
||||
self.Rotation_wc = np.dot(Rotation_wb, self.Rotation_bc)
|
||||
|
||||
if self.odom_ref_init:
|
||||
odom_data = self.odom_ref
|
||||
# vel_b
|
||||
vel_w = np.array([odom_data.twist.twist.linear.x,
|
||||
odom_data.twist.twist.linear.y,
|
||||
odom_data.twist.twist.linear.z])
|
||||
vel_b = np.dot(np.linalg.inv(self.Rotation_wc), vel_w)
|
||||
# acc_b
|
||||
acc_w = np.array([odom_data.twist.twist.angular.x, # acc stored in angular in our ref_state topic
|
||||
odom_data.twist.twist.angular.y,
|
||||
odom_data.twist.twist.angular.z])
|
||||
acc_b = np.dot(np.linalg.inv(self.Rotation_wc), acc_w)
|
||||
else:
|
||||
odom_data = self.odom
|
||||
vel_b = np.array([0.0, 0.0, 0.0])
|
||||
acc_b = np.array([0.0, 0.0, 0.0])
|
||||
|
||||
# pose
|
||||
pos = np.array([odom_data.pose.pose.position.x,
|
||||
odom_data.pose.pose.position.y,
|
||||
odom_data.pose.pose.position.z])
|
||||
|
||||
# goal_dir
|
||||
goal_w = (self.goal - pos) / np.linalg.norm(self.goal - pos)
|
||||
goal_b = np.dot(np.linalg.inv(self.Rotation_wc), goal_w)
|
||||
|
||||
vel_acc = np.concatenate((vel_b, acc_b), axis=0)
|
||||
vel_acc_norm = self.normalize_obs(vel_acc[np.newaxis, :])
|
||||
obs_norm = np.hstack((vel_acc_norm, goal_b[np.newaxis, :]))
|
||||
return obs_norm
|
||||
|
||||
def callback_depth(self, data):
|
||||
max_dis = 20.0
|
||||
min_dis = 0.03
|
||||
if self.env == '435':
|
||||
scale = 0.001
|
||||
elif self.env == 'flightmare':
|
||||
scale = 1.0
|
||||
|
||||
try:
|
||||
depth_ = self.bridge.imgmsg_to_cv2(data, "32FC1")
|
||||
except:
|
||||
print("CV_bridge ERROR: The ROS path is not included in Python Path!")
|
||||
|
||||
if depth_.shape[0] != self.height or depth_.shape[1] != self.width:
|
||||
depth_ = cv2.resize(depth_, (self.width, self.height), interpolation=cv2.INTER_NEAREST)
|
||||
depth_ = np.minimum(depth_ * scale, max_dis) / max_dis
|
||||
|
||||
# interpolated the nan value (experiment shows that treating nan directly as 0 produces similar results)
|
||||
start = time.time()
|
||||
nan_mask = np.isnan(depth_) | (depth_ < min_dis)
|
||||
interpolated_image = cv2.inpaint(np.uint8(depth_ * 255), np.uint8(nan_mask), 1, cv2.INPAINT_NS)
|
||||
interpolated_image = interpolated_image.astype(np.float32) / 255.0
|
||||
depth_ = interpolated_image.reshape([1, 1, self.height, self.width])
|
||||
if self.verbose:
|
||||
self.time_interpolation = self.time_interpolation + (time.time() - start)
|
||||
self.count_interpolation = self.count_interpolation + 1
|
||||
print("interpolation time:", self.time_interpolation / self.count_interpolation)
|
||||
|
||||
# cv2.imshow("1", depth_[0][0])
|
||||
# cv2.waitKey(1)
|
||||
self.new_depth = True
|
||||
self.depth = depth_.astype(np.float32)
|
||||
|
||||
def callback_set_goal(self, data):
|
||||
self.goal = np.asarray([data.pose.position.x, data.pose.position.y, 2])
|
||||
print("New goal:", self.goal)
|
||||
|
||||
def test_policy(self, _timer):
|
||||
if self.new_depth and self.new_odom:
|
||||
self.new_odom = False
|
||||
self.new_depth = False
|
||||
obs = self.process_odom()
|
||||
odom_sec = self.odom.header.stamp.to_sec()
|
||||
|
||||
# input prepare
|
||||
time0 = time.time()
|
||||
depth = torch.from_numpy(self.depth).to(self.device, non_blocking=True) # (non_blocking: copying speed 3x)
|
||||
obs_norm_input = self.prepare_input_observation(obs)
|
||||
obs_norm_input = obs_norm_input.to(self.device, non_blocking=True)
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
# forward
|
||||
if self.use_trt: # TensorRT (inference speed increased by 10x)
|
||||
time1 = time.time()
|
||||
trt_output = self.policy(depth, obs_norm_input)
|
||||
time2 = time.time()
|
||||
endstate_pred, score_pred = self.trt_process(trt_output, return_all_preds=self.visualize)
|
||||
endstate_pred = endstate_pred.squeeze()
|
||||
time3 = time.time()
|
||||
else:
|
||||
endstate_pred, score_pred = self.policy.predict(depth, obs_norm_input, return_all_preds=self.visualize)
|
||||
endstate_pred = endstate_pred.cpu().numpy().squeeze()
|
||||
score_pred = score_pred.cpu().numpy()
|
||||
|
||||
# Transform the prediction(body frame) to the world frame with the attitude in inference
|
||||
# Replacing PyTorch calculations on CUDA with NumPy calculations on the CPU (speed increased by 10x)
|
||||
endstate_b = endstate_pred
|
||||
endstate_w = np.zeros_like(endstate_b)
|
||||
traj_num = self.lattice_space.horizon_num * self.lattice_space.vertical_num if self.visualize else 1
|
||||
Pb, Vb, Ab = [np.zeros((3, traj_num)) for _ in range(3)]
|
||||
for i in range(3):
|
||||
Pb[i] = endstate_b[3 * i]
|
||||
Vb[i] = endstate_b[3 * i + 1]
|
||||
Ab[i] = endstate_b[3 * i + 2]
|
||||
# pos_actual = np.array([self.odom.pose.pose.position.x,
|
||||
# self.odom.pose.pose.position.y,
|
||||
# self.odom.pose.pose.position.z])
|
||||
Pw = np.dot(self.Rotation_wc, Pb) # + np.tile(pos_actual, (15, 1)).T
|
||||
Vw = np.dot(self.Rotation_wc, Vb)
|
||||
Aw = np.dot(self.Rotation_wc, Ab)
|
||||
for i in range(3):
|
||||
endstate_w[3 * i] = Pw[i]
|
||||
endstate_w[3 * i + 1] = Vw[i]
|
||||
endstate_w[3 * i + 2] = Aw[i]
|
||||
|
||||
if self.verbose:
|
||||
self.time_prepare = self.time_prepare + (time1 - time0)
|
||||
self.time_forward = self.time_forward + (time2 - time1)
|
||||
self.time_process = self.time_process + (time3 - time2)
|
||||
self.count = self.count + 1
|
||||
print("time forward:", self.time_forward / self.count, "process:", self.time_process / self.count,
|
||||
"prepare:", self.time_prepare / self.count)
|
||||
|
||||
# publish
|
||||
if not self.visualize:
|
||||
endstate_pred_to_pub = Float32MultiArray(data=endstate_w.reshape(-1))
|
||||
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
|
||||
self.endstate_pub.publish(endstate_pred_to_pub)
|
||||
else:
|
||||
action_id = np.argmin(score_pred)
|
||||
best_endstate_pred = endstate_w[:, action_id].reshape(-1)
|
||||
endstate_pred_to_pub = Float32MultiArray(data=best_endstate_pred)
|
||||
endstate_pred_to_pub.layout.data_offset = int(1000 * odom_sec) % 1000000 # 预测时用的里程计时间戳(ms)
|
||||
self.endstate_pub.publish(endstate_pred_to_pub)
|
||||
# visualization
|
||||
endstate_score_preds = np.concatenate((endstate_w, score_pred), axis=0)
|
||||
all_endstate_pred = Float32MultiArray(data=endstate_score_preds.T.reshape(-1))
|
||||
all_endstate_pred.layout.dim.append(MultiArrayDimension())
|
||||
all_endstate_pred.layout.dim[0].size = endstate_score_preds.shape[1]
|
||||
all_endstate_pred.layout.dim[0].label = "primitive_num"
|
||||
all_endstate_pred.layout.dim.append(MultiArrayDimension())
|
||||
all_endstate_pred.layout.dim[1].size = endstate_score_preds.shape[0]
|
||||
all_endstate_pred.layout.dim[1].label = "endstate_and_score_num"
|
||||
self.all_endstate_pub.publish(all_endstate_pred)
|
||||
self.goal_pub.publish(Float32MultiArray(data=self.goal))
|
||||
else:
|
||||
if not self.new_odom: # start a new round
|
||||
self.odom_ref_init = False
|
||||
|
||||
def trt_process(self, input_tensor: torch.Tensor, return_all_preds=False) -> torch.Tensor:
|
||||
batch_size = input_tensor.shape[0]
|
||||
input_tensor = input_tensor.cpu().numpy()
|
||||
input_tensor = input_tensor.reshape(batch_size, 10,
|
||||
self.lattice_space.horizon_num * self.lattice_space.vertical_num)
|
||||
endstate_pred = input_tensor[:, 0:9, :]
|
||||
score_pred = input_tensor[:, 9, :]
|
||||
|
||||
if not return_all_preds:
|
||||
endstate_prediction = np.zeros((batch_size, 9))
|
||||
score_prediction = np.zeros((batch_size, 1))
|
||||
for i in range(0, batch_size):
|
||||
action_id = np.argmin(score_pred[i])
|
||||
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - action_id
|
||||
endstate_prediction[i] = self.pred_to_endstate(np.expand_dims(endstate_pred[i, :, action_id], axis=0), lattice_id)
|
||||
score_prediction[i] = score_pred[i, action_id]
|
||||
else:
|
||||
endstate_prediction = np.zeros_like(endstate_pred)
|
||||
score_prediction = score_pred
|
||||
for i in range(0, self.lattice_space.horizon_num * self.lattice_space.vertical_num):
|
||||
lattice_id = self.lattice_space.horizon_num * self.lattice_space.vertical_num - 1 - i
|
||||
endstate = self.pred_to_endstate(endstate_pred[:, :, i], lattice_id)
|
||||
endstate_prediction[:, :, i] = endstate
|
||||
|
||||
return endstate_prediction, score_prediction
|
||||
|
||||
def prepare_input_observation(self, obs):
|
||||
"""
|
||||
convert the observation from body frame to primitive frame,
|
||||
and then concatenate it with the depth features (to ensure the translational invariance)
|
||||
"""
|
||||
obs_return = np.ones(
|
||||
(obs.shape[0], self.lattice_space.vertical_num, self.lattice_space.horizon_num, obs.shape[1]),
|
||||
dtype=np.float32)
|
||||
id = 0
|
||||
v_b = obs[:, 0:3]
|
||||
a_b = obs[:, 3:6]
|
||||
g_b = obs[:, 6:9]
|
||||
for i in range(self.lattice_space.vertical_num - 1, -1, -1):
|
||||
for j in range(self.lattice_space.horizon_num - 1, -1, -1):
|
||||
Rbp = self.lattice_primitive.getRotation(id)
|
||||
v_p = np.dot(Rbp.T, v_b.T).T
|
||||
a_p = np.dot(Rbp.T, a_b.T).T
|
||||
g_p = np.dot(Rbp.T, g_b.T).T
|
||||
obs_return[:, i, j, 0:3] = v_p
|
||||
obs_return[:, i, j, 3:6] = a_p
|
||||
obs_return[:, i, j, 6:9] = g_p
|
||||
# obs_return[:, i, j, 0:6] = self.normalize_obs(obs_return[:, i, j, 0:6])
|
||||
id = id + 1
|
||||
obs_return = np.transpose(obs_return, [0, 3, 1, 2])
|
||||
return torch.from_numpy(obs_return)
|
||||
|
||||
def pred_to_endstate(self, endstate_pred: np.ndarray, id: int):
|
||||
"""
|
||||
Transform the predicted state to the body frame.
|
||||
"""
|
||||
delta_yaw = endstate_pred[:, 0] * self.lattice_primitive.yaw_diff
|
||||
delta_pitch = endstate_pred[:, 1] * self.lattice_primitive.pitch_diff
|
||||
radio = endstate_pred[:, 2] * self.lattice_space.radio_range + self.lattice_space.radio_range
|
||||
yaw, pitch = self.lattice_primitive.getAngleLattice(id)
|
||||
endstate_x = np.cos(pitch + delta_pitch) * np.cos(yaw + delta_yaw) * radio
|
||||
endstate_y = np.cos(pitch + delta_pitch) * np.sin(yaw + delta_yaw) * radio
|
||||
endstate_z = np.sin(pitch + delta_pitch) * radio
|
||||
endstate_p = np.stack((endstate_x, endstate_y, endstate_z), axis=1)
|
||||
|
||||
endstate_vp = endstate_pred[:, 3:6] * self.lattice_space.vel_max
|
||||
endstate_ap = endstate_pred[:, 6:9] * self.lattice_space.acc_max
|
||||
Rbp = self.lattice_primitive.getRotation(id)
|
||||
endstate_vb = np.matmul(Rbp, endstate_vp.T).T
|
||||
endstate_ab = np.matmul(Rbp, endstate_ap.T).T
|
||||
endstate = np.concatenate((endstate_p, endstate_vb, endstate_ab), axis=1)
|
||||
endstate[:, [0, 1, 2, 3, 4, 5, 6, 7, 8]] = endstate[:, [0, 3, 6, 1, 4, 7, 2, 5, 8]]
|
||||
return endstate
|
||||
|
||||
def normalize_obs(self, vel_acc):
|
||||
vel_norm = vel_acc[:, 0:3] / self.lattice_space.vel_max
|
||||
acc_norm = vel_acc[:, 3:6] / self.lattice_space.acc_max
|
||||
return np.hstack((vel_norm, acc_norm))
|
||||
|
||||
def warm_up(self):
|
||||
depth = np.zeros(shape=[1, 1, self.height, self.width], dtype=np.float32)
|
||||
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
||||
obs_input = self.prepare_input_observation(obs)
|
||||
if self.use_trt:
|
||||
trt_output = self.policy(torch.from_numpy(depth).to(self.device), obs_input.to(self.device))
|
||||
self.trt_process(trt_output, return_all_preds=True)
|
||||
else:
|
||||
self.policy.predict(torch.from_numpy(depth).to(self.device), obs_input.to(self.device),
|
||||
return_all_preds=True)
|
||||
|
||||
|
||||
def parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--use_tensorrt", type=int, default=0, help="use tensorrt or not")
|
||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||
return parser
|
||||
|
||||
|
||||
# In realworld flight: visualize=False; use_tensorrt=True, and ensure the pitch_angle consistent with your platform
|
||||
# When modifying the pitch_angle, there's no need to re-collect and re-train, as all predictions are in the camera coordinate system
|
||||
def main():
|
||||
args = parser().parse_args()
|
||||
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
||||
if args.use_tensorrt:
|
||||
weight = "yopo_trt.pth"
|
||||
else:
|
||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
||||
|
||||
settings = {'use_tensorrt': args.use_tensorrt,
|
||||
'network_frequency': 30,
|
||||
'img_height': 96,
|
||||
'img_width': 160,
|
||||
'goal': [20, 20, 2], # the goal
|
||||
'env': 'flightmare', # use Realsense D435 or Flightmare Simulator ('435' or 'flightmare')
|
||||
'pitch_angle_deg': -5, # pitch of camera, ensure consistent with the simulator or your platform (no need to re-collect and re-train when modifying)
|
||||
'odom_topic': '/juliett/ground_truth/odom',
|
||||
'depth_topic': '/depth_image',
|
||||
'verbose': False, # print the latency?
|
||||
'visualize': True # visualize all predictions? set False in real flight
|
||||
}
|
||||
YopoNet(settings, weight)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
25
run/utils/log_plot.py
Executable file
25
run/utils/log_plot.py
Executable file
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if __name__ == '__main__':
|
||||
file_path = "/home/lu/flightmare/flightmare/run/utils/dist.csv"
|
||||
temp = np.loadtxt(file_path, dtype=float, delimiter=",")
|
||||
file_path = "/home/lu/flightmare/flightmare/run/utils/dist_x.csv"
|
||||
tempX = np.loadtxt(file_path, dtype=float, delimiter=",")
|
||||
plt.plot(tempX, temp)
|
||||
plt.show()
|
||||
print("dist min:", np.min(temp))
|
||||
file_path = "/home/lu/flightmare/flightmare/run/utils/ctrl_log.csv"
|
||||
ctrl_log = np.loadtxt(file_path, dtype=float, delimiter=",")
|
||||
v_total = np.sqrt(
|
||||
ctrl_log[:, 3] * ctrl_log[:, 3] + ctrl_log[:, 4] * ctrl_log[:, 4] + ctrl_log[:, 5] * ctrl_log[:, 5])
|
||||
print("v max: ", np.max(v_total))
|
||||
plt.plot(ctrl_log[:, 3], label='vx')
|
||||
plt.plot(ctrl_log[:, 4], label='vy')
|
||||
plt.plot(ctrl_log[:, 5], label='vz')
|
||||
plt.plot(v_total, label='v_total')
|
||||
plt.plot(ctrl_log[:, 6], label='ax')
|
||||
plt.plot(ctrl_log[:, 7], label='ay')
|
||||
plt.plot(ctrl_log[:, 8], label='az')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
52
run/utils/pointcloud_clip.py
Normal file
52
run/utils/pointcloud_clip.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# 实飞数据训练:将全局地图裁剪并保存
|
||||
# 1、注意数据收集时,地面尽量平,且需要为z=0
|
||||
# 2、收集数据不平时,修改yaw_angle_radians, pitch_angle_radians平移,并与data collection一致
|
||||
# 3、bug:需要打开保存的文件,手动把前面几行的double改成float...
|
||||
|
||||
import open3d as o3d
|
||||
import numpy as np
|
||||
|
||||
# 1. 加载点云数据
|
||||
point_cloud = o3d.io.read_point_cloud("1.pcd") # 替换为点云文件的路径
|
||||
|
||||
|
||||
# # 统计离群点移除滤波
|
||||
# cl, ind = cropped_point_cloud.remove_statistical_outlier(nb_neighbors=5, std_ratio=1.0) # 调整参数以控制移除离群点的程度
|
||||
# filtered_cloud = cropped_point_cloud.select_by_index(ind)
|
||||
|
||||
# 2. 定义旋转角度(偏航角和俯仰角)
|
||||
yaw_angle_degrees = -15 # 偏航角(以度为单位)
|
||||
pitch_angle_degrees = -3 # 俯仰角(以度为单位)
|
||||
# 3. 将角度转换为弧度
|
||||
yaw_angle_radians = np.radians(yaw_angle_degrees)
|
||||
pitch_angle_radians = np.radians(pitch_angle_degrees)
|
||||
|
||||
yaw_rotation = np.array([[np.cos(yaw_angle_radians), -np.sin(yaw_angle_radians), 0],
|
||||
[np.sin(yaw_angle_radians), np.cos(yaw_angle_radians), 0],
|
||||
[0, 0, 1]])
|
||||
|
||||
pitch_rotation = np.array([[np.cos(pitch_angle_radians), 0, np.sin(pitch_angle_radians)],
|
||||
[0, 1, 0],
|
||||
[-np.sin(pitch_angle_radians), 0, np.cos(pitch_angle_radians)]])
|
||||
# 4. 平移2米到Z方向
|
||||
translation_no = np.array([0, 0, 2]) # 平移2米到Z方向
|
||||
|
||||
# 5. 组合旋转矩阵 R old->new
|
||||
R_on = np.dot(yaw_rotation, pitch_rotation) # 内旋是右乘,先yaw后pitch
|
||||
# P_n = (R_no * P_o.T).T + t_no = P_o * R_on + t_no
|
||||
point_cloud.points = o3d.utility.Vector3dVector(np.dot(np.asarray(point_cloud.points), R_on) + translation_no)
|
||||
|
||||
# o3d.visualization.draw_geometries([point_cloud])
|
||||
|
||||
|
||||
# 2. 定义裁剪范围
|
||||
# 例如,裁剪一个立方体范围,这里给出立方体的最小点和最大点坐标
|
||||
min_bound = np.array([-5.0, -18.0, 0]) # 最小点坐标
|
||||
max_bound = np.array([150.0, 25.0, 6]) # 最大点坐标
|
||||
|
||||
# 3. 使用crop函数裁剪点云
|
||||
cropped_point_cloud = point_cloud.crop(o3d.geometry.AxisAlignedBoundingBox(min_bound, max_bound))
|
||||
|
||||
o3d.io.write_point_cloud("realworld.ply", cropped_point_cloud, write_ascii=True)
|
||||
|
||||
o3d.visualization.draw_geometries([cropped_point_cloud])
|
||||
101
run/yopo_trt_transfer.py
Normal file
101
run/yopo_trt_transfer.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
将yopo模型转换为Tensorrt
|
||||
prepare:
|
||||
1 pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com
|
||||
2 git clone https://github.com/NVIDIA-AI-IOT/torch2trt
|
||||
cd torch2trt
|
||||
python setup.py install
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch2trt import torch2trt
|
||||
from flightgym import QuadrotorEnv_v1
|
||||
from ruamel.yaml import YAML, RoundTripDumper, dump
|
||||
from flightpolicy.envs import vec_env_wrapper as wrapper
|
||||
from flightpolicy.yopo.yopo_algorithm import YopoAlgorithm
|
||||
|
||||
|
||||
def prapare_input_observation(obs, lattice_space, lattice_primitive):
|
||||
obs_return = np.ones(
|
||||
(obs.shape[0], lattice_space.vertical_num, lattice_space.horizon_num, obs.shape[1]),
|
||||
dtype=np.float32)
|
||||
id = 0
|
||||
v_b = obs[:, 0:3]
|
||||
a_b = obs[:, 3:6]
|
||||
g_b = obs[:, 6:9]
|
||||
for i in range(lattice_space.vertical_num - 1, -1, -1):
|
||||
for j in range(lattice_space.horizon_num - 1, -1, -1):
|
||||
Rbp = lattice_primitive.getRotation(id)
|
||||
v_p = np.dot(Rbp.T, v_b.T).T
|
||||
a_p = np.dot(Rbp.T, a_b.T).T
|
||||
g_p = np.dot(Rbp.T, g_b.T).T
|
||||
obs_return[:, i, j, 0:3] = v_p
|
||||
obs_return[:, i, j, 3:6] = a_p
|
||||
obs_return[:, i, j, 6:9] = g_p
|
||||
id = id + 1
|
||||
obs_return = np.transpose(obs_return, [0, 3, 1, 2])
|
||||
return obs_return
|
||||
|
||||
|
||||
def parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--trial", type=int, default=1, help="trial number")
|
||||
parser.add_argument("--epoch", type=int, default=0, help="epoch number")
|
||||
parser.add_argument("--iter", type=int, default=0, help="iter number")
|
||||
parser.add_argument("--dir", type=str, default='yopo_trt.pth', help="output file name")
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
args = parser().parse_args()
|
||||
# load configurations
|
||||
cfg = YAML().load(open(os.environ["FLIGHTMARE_PATH"] + "/flightlib/configs/vec_env.yaml", 'r'))
|
||||
cfg["env"]["num_envs"] = 1
|
||||
cfg["env"]["supervised"] = False
|
||||
cfg["env"]["imitation"] = False
|
||||
cfg["env"]["render"] = False
|
||||
|
||||
# create environment
|
||||
train_env = QuadrotorEnv_v1(dump(cfg, Dumper=RoundTripDumper), False)
|
||||
train_env = wrapper.FlightEnvVec(train_env)
|
||||
model = YopoAlgorithm(env=train_env,
|
||||
policy_kwargs=dict(
|
||||
activation_fn=torch.nn.ReLU,
|
||||
net_arch=[256, 256],
|
||||
hidden_state=64
|
||||
))
|
||||
|
||||
rsg_root = os.path.dirname(os.path.abspath(__file__))
|
||||
weight = rsg_root + "/saved/YOPO_{}/Policy/epoch{}_iter{}.pth".format(args.trial, args.epoch, args.iter)
|
||||
device = torch.device("cuda")
|
||||
saved_variables = torch.load(weight, map_location=device)
|
||||
model.policy.load_state_dict(saved_variables["state_dict"], strict=False)
|
||||
model.policy.set_training_mode(False)
|
||||
|
||||
lattice_space = saved_variables["data"]["lattice_space"]
|
||||
lattice_primitive = saved_variables["data"]["lattice_primitive"]
|
||||
|
||||
# The inputs should be consistent with training
|
||||
depth = np.zeros(shape=[1, 1, 96, 160], dtype=np.float32)
|
||||
obs = np.zeros(shape=[1, 9], dtype=np.float32)
|
||||
obs_input = prapare_input_observation(obs, lattice_space, lattice_primitive)
|
||||
depth_in = torch.from_numpy(depth).cuda()
|
||||
obs_in = torch.from_numpy(obs_input).cuda()
|
||||
model_trt = torch2trt(model.policy, [depth_in, obs_in])
|
||||
torch.save(model_trt.state_dict(), args.dir)
|
||||
|
||||
# from torch2trt import TRTModule
|
||||
# model_trt = TRTModule()
|
||||
# model_trt.load_state_dict(torch.load('yopo_trt.pth'))
|
||||
|
||||
y_trt = model_trt(depth_in, obs_in)
|
||||
y = model.policy(depth_in, obs_in)
|
||||
error = torch.mean(torch.abs(y - y_trt))
|
||||
print("transfer error: ", error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user