modify dataloader
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import os, sys
|
import os, sys
|
||||||
import cv2
|
import cv2
|
||||||
import time
|
import time
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from scipy.spatial.transform import Rotation as R
|
from scipy.spatial.transform import Rotation as R
|
||||||
@@ -44,10 +45,10 @@ class YOPODataset(Dataset):
|
|||||||
for data_idx in range(len(datafolders)):
|
for data_idx in range(len(datafolders)):
|
||||||
datafolder = datafolders[data_idx]
|
datafolder = datafolders[data_idx]
|
||||||
|
|
||||||
image_file_names = [filename
|
image_file_names = [datafolder + "/" + filename
|
||||||
for filename in os.listdir(datafolder)
|
for filename in os.listdir(datafolder)
|
||||||
if os.path.splitext(filename)[1] == '.png']
|
if os.path.splitext(filename)[1] == '.png']
|
||||||
image_file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1])) # sort by filename to align with the label
|
image_file_names.sort(key=lambda x: int(os.path.basename(x).split('.')[0].split("_")[1])) # sort by filename to align with the label
|
||||||
|
|
||||||
states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32)
|
states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32)
|
||||||
positions = states[:, 0:3]
|
positions = states[:, 0:3]
|
||||||
@@ -57,28 +58,20 @@ class YOPODataset(Dataset):
|
|||||||
image_file_names, positions, quaternions, test_size=val_ratio, random_state=0)
|
image_file_names, positions, quaternions, test_size=val_ratio, random_state=0)
|
||||||
|
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_train]
|
self.img_list.extend(file_names_train)
|
||||||
self.img_list.extend(images)
|
|
||||||
self.positions = np.vstack((self.positions, positions_train.astype(np.float32)))
|
self.positions = np.vstack((self.positions, positions_train.astype(np.float32)))
|
||||||
self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32)))
|
self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32)))
|
||||||
|
self.map_idx.extend([data_idx] * len(file_names_train))
|
||||||
elif mode == 'valid':
|
elif mode == 'valid':
|
||||||
images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_val]
|
self.img_list.extend(file_names_val)
|
||||||
self.img_list.extend(images)
|
|
||||||
self.positions = np.vstack((self.positions, positions_val.astype(np.float32)))
|
self.positions = np.vstack((self.positions, positions_val.astype(np.float32)))
|
||||||
self.quaternions = np.vstack((self.quaternions, quaternions_val.astype(np.float32)))
|
self.quaternions = np.vstack((self.quaternions, quaternions_val.astype(np.float32)))
|
||||||
|
self.map_idx.extend([data_idx] * len(file_names_val))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid mode {mode}. Choose from 'train', 'valid'.")
|
raise ValueError(f"Invalid mode {mode}. Choose from 'train', 'valid'.")
|
||||||
|
|
||||||
self.map_idx.extend([data_idx] * len(images))
|
|
||||||
|
|
||||||
# NOTE: The depth images are normalized from 0–20m to a 0–1 and converted to int16 during data collection.
|
|
||||||
self.img_list = [np.expand_dims(
|
|
||||||
cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0,
|
|
||||||
axis=0)
|
|
||||||
for img in self.img_list]
|
|
||||||
|
|
||||||
print(f"=============== {mode.capitalize()} Data Summary ===============")
|
print(f"=============== {mode.capitalize()} Data Summary ===============")
|
||||||
print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.img_list[0].shape}")
|
print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.width},{self.height}")
|
||||||
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
|
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
|
||||||
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
|
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
|
||||||
print("==================================================")
|
print("==================================================")
|
||||||
@@ -88,9 +81,15 @@ class YOPODataset(Dataset):
|
|||||||
return len(self.img_list)
|
return len(self.img_list)
|
||||||
|
|
||||||
def __getitem__(self, item):
|
def __getitem__(self, item):
|
||||||
|
# 1. read the image
|
||||||
|
# NOTE: The depth images are normalized from 0–20m to a 0–1 and converted to int16 during data collection.
|
||||||
|
image = cv2.imread(self.img_list[item], -1).astype(np.float32)
|
||||||
|
image = np.expand_dims(cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0, axis=0)
|
||||||
|
|
||||||
|
# 2. get random vel, acc
|
||||||
vel_b, acc_b = self._get_random_state()
|
vel_b, acc_b = self._get_random_state()
|
||||||
|
|
||||||
# generate random goal in front of the quadrotor.
|
# 3. generate random goal in front of the quadrotor.
|
||||||
q_wxyz = self.quaternions[item, :] # q: wxyz
|
q_wxyz = self.quaternions[item, :] # q: wxyz
|
||||||
R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])
|
R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])
|
||||||
euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)]
|
euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)]
|
||||||
@@ -101,7 +100,7 @@ class YOPODataset(Dataset):
|
|||||||
random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32)
|
random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32)
|
||||||
rot_wb = R_WB.as_matrix().astype(np.float32) # transform to rot_matrix in numpy is faster than using quat in pytorch
|
rot_wb = R_WB.as_matrix().astype(np.float32) # transform to rot_matrix in numpy is faster than using quat in pytorch
|
||||||
# vel & acc & goal are in body frame, NWU, and no-normalization
|
# vel & acc & goal are in body frame, NWU, and no-normalization
|
||||||
return self.img_list[item], self.positions[item], rot_wb, random_obs, self.map_idx[item]
|
return image, self.positions[item], rot_wb, random_obs, self.map_idx[item]
|
||||||
|
|
||||||
def _get_random_state(self):
|
def _get_random_state(self):
|
||||||
while True:
|
while True:
|
||||||
@@ -212,14 +211,23 @@ class YOPODataset(Dataset):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
dataset = YOPODataset()
|
dataset = YOPODataset()
|
||||||
dataset.plot_sample_distribution()
|
# dataset.plot_sample_distribution()
|
||||||
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
|
|
||||||
|
|
||||||
start = time.time()
|
dataset = YOPODataset()
|
||||||
for epoch in range(1):
|
max_workers = os.cpu_count()
|
||||||
last = time.time()
|
print(f"\n✅ cpu_count = {max_workers}")
|
||||||
for i, (depth, pos, quat, obs, id) in enumerate(data_loader):
|
|
||||||
pass
|
|
||||||
end = time.time()
|
|
||||||
|
|
||||||
print("加载1个epoch总耗时:", end - start)
|
results = []
|
||||||
|
for nw in range(0, max_workers + 1):
|
||||||
|
data_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=nw)
|
||||||
|
start = time.time()
|
||||||
|
for i, _ in enumerate(data_loader):
|
||||||
|
if i > 50: # 只测前50个batch
|
||||||
|
break
|
||||||
|
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||||
|
elapsed = time.time() - start
|
||||||
|
results.append((nw, elapsed))
|
||||||
|
print(f"num_workers={nw}: {elapsed:.3f}s")
|
||||||
|
|
||||||
|
best = min(results, key=lambda x: x[1])
|
||||||
|
print(f"\n✅ 最优 num_workers = {best[0]}, 平均耗时={best[1]:.3f}s")
|
||||||
|
|||||||
@@ -57,11 +57,11 @@ class YopoTrainer:
|
|||||||
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
|
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
|
||||||
print("Network Loaded! Loading Dataset...")
|
print("Network Loaded! Loading Dataset...")
|
||||||
|
|
||||||
# dataset
|
# dataset (you can adjust num_workers according to your training speed)
|
||||||
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True,
|
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True,
|
||||||
num_workers=1, pin_memory=True)
|
num_workers=4, pin_memory=True)
|
||||||
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False,
|
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False,
|
||||||
num_workers=1, pin_memory=True)
|
num_workers=4, pin_memory=True)
|
||||||
print("Dataset Loaded!")
|
print("Dataset Loaded!")
|
||||||
|
|
||||||
def train(self, epoch, save_interval=None):
|
def train(self, epoch, save_interval=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user