modify dataloader

This commit is contained in:
TJU-Lu
2025-10-15 21:34:00 +08:00
parent 52e144f20d
commit c445f0727f
2 changed files with 37 additions and 29 deletions

View File

@@ -1,6 +1,7 @@
import os, sys import os, sys
import cv2 import cv2
import time import time
import torch
import numpy as np import numpy as np
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from scipy.spatial.transform import Rotation as R from scipy.spatial.transform import Rotation as R
@@ -44,10 +45,10 @@ class YOPODataset(Dataset):
for data_idx in range(len(datafolders)): for data_idx in range(len(datafolders)):
datafolder = datafolders[data_idx] datafolder = datafolders[data_idx]
image_file_names = [filename image_file_names = [datafolder + "/" + filename
for filename in os.listdir(datafolder) for filename in os.listdir(datafolder)
if os.path.splitext(filename)[1] == '.png'] if os.path.splitext(filename)[1] == '.png']
image_file_names.sort(key=lambda x: int(x.split('.')[0].split("_")[1])) # sort by filename to align with the label image_file_names.sort(key=lambda x: int(os.path.basename(x).split('.')[0].split("_")[1])) # sort by filename to align with the label
states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32) states = np.loadtxt(data_dir + f"/pose-{data_idx}.csv", delimiter=',', skiprows=1).astype(np.float32)
positions = states[:, 0:3] positions = states[:, 0:3]
@@ -57,28 +58,20 @@ class YOPODataset(Dataset):
image_file_names, positions, quaternions, test_size=val_ratio, random_state=0) image_file_names, positions, quaternions, test_size=val_ratio, random_state=0)
if mode == 'train': if mode == 'train':
images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_train] self.img_list.extend(file_names_train)
self.img_list.extend(images)
self.positions = np.vstack((self.positions, positions_train.astype(np.float32))) self.positions = np.vstack((self.positions, positions_train.astype(np.float32)))
self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32))) self.quaternions = np.vstack((self.quaternions, quaternions_train.astype(np.float32)))
self.map_idx.extend([data_idx] * len(file_names_train))
elif mode == 'valid': elif mode == 'valid':
images = [cv2.imread(datafolder + "/" + filename, -1).astype(np.float32) for filename in file_names_val] self.img_list.extend(file_names_val)
self.img_list.extend(images)
self.positions = np.vstack((self.positions, positions_val.astype(np.float32))) self.positions = np.vstack((self.positions, positions_val.astype(np.float32)))
self.quaternions = np.vstack((self.quaternions, quaternions_val.astype(np.float32))) self.quaternions = np.vstack((self.quaternions, quaternions_val.astype(np.float32)))
self.map_idx.extend([data_idx] * len(file_names_val))
else: else:
raise ValueError(f"Invalid mode {mode}. Choose from 'train', 'valid'.") raise ValueError(f"Invalid mode {mode}. Choose from 'train', 'valid'.")
self.map_idx.extend([data_idx] * len(images))
# NOTE: The depth images are normalized from 0~20m to 0~1 and converted to int16 during data collection.
self.img_list = [np.expand_dims(
cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0,
axis=0)
for img in self.img_list]
print(f"=============== {mode.capitalize()} Data Summary ===============") print(f"=============== {mode.capitalize()} Data Summary ===============")
print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.img_list[0].shape}") print(f"{'Images' :<12} | Count: {len(self.img_list):<3} | Shape: {self.width},{self.height}")
print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}") print(f"{'Positions' :<12} | Count: {self.positions.shape[0]:<3} | Shape: {self.positions.shape[1]}")
print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}") print(f"{'Quaternions' :<12} | Count: {self.quaternions.shape[0]:<3} | Shape: {self.quaternions.shape[1]}")
print("==================================================") print("==================================================")
@@ -88,9 +81,15 @@ class YOPODataset(Dataset):
return len(self.img_list) return len(self.img_list)
def __getitem__(self, item): def __getitem__(self, item):
# 1. read the image
# NOTE: The depth images are normalized from 0~20m to 0~1 and converted to int16 during data collection.
image = cv2.imread(self.img_list[item], -1).astype(np.float32)
image = np.expand_dims(cv2.resize(image, (self.width, self.height), interpolation=cv2.INTER_NEAREST) / 65535.0, axis=0)
# 2. get random vel, acc
vel_b, acc_b = self._get_random_state() vel_b, acc_b = self._get_random_state()
# generate random goal in front of the quadrotor. # 3. generate random goal in front of the quadrotor.
q_wxyz = self.quaternions[item, :] # q: wxyz q_wxyz = self.quaternions[item, :] # q: wxyz
R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]]) R_WB = R.from_quat([q_wxyz[1], q_wxyz[2], q_wxyz[3], q_wxyz[0]])
euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)] euler_angles = R_WB.as_euler('ZYX', degrees=False) # [yaw(z) pitch(y) roll(x)]
@@ -101,7 +100,7 @@ class YOPODataset(Dataset):
random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32) random_obs = np.hstack((vel_b, acc_b, goal_b)).astype(np.float32)
rot_wb = R_WB.as_matrix().astype(np.float32) # transform to rot_matrix in numpy is faster than using quat in pytorch rot_wb = R_WB.as_matrix().astype(np.float32) # transform to rot_matrix in numpy is faster than using quat in pytorch
# vel & acc & goal are in body frame, NWU, and no-normalization # vel & acc & goal are in body frame, NWU, and no-normalization
return self.img_list[item], self.positions[item], rot_wb, random_obs, self.map_idx[item] return image, self.positions[item], rot_wb, random_obs, self.map_idx[item]
def _get_random_state(self): def _get_random_state(self):
while True: while True:
@@ -212,14 +211,23 @@ class YOPODataset(Dataset):
if __name__ == '__main__': if __name__ == '__main__':
dataset = YOPODataset() dataset = YOPODataset()
dataset.plot_sample_distribution() # dataset.plot_sample_distribution()
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
dataset = YOPODataset()
max_workers = os.cpu_count()
print(f"\n✅ cpu_count = {max_workers}")
results = []
for nw in range(0, max_workers + 1):
data_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=nw)
start = time.time() start = time.time()
for epoch in range(1): for i, _ in enumerate(data_loader):
last = time.time() if i > 50: # 只测前50个batch
for i, (depth, pos, quat, obs, id) in enumerate(data_loader): break
pass torch.cuda.synchronize() if torch.cuda.is_available() else None
end = time.time() elapsed = time.time() - start
results.append((nw, elapsed))
print(f"num_workers={nw}: {elapsed:.3f}s")
print("加载1个epoch总耗时", end - start) best = min(results, key=lambda x: x[1])
print(f"\n✅ 最优 num_workers = {best[0]}, 平均耗时={best[1]:.3f}s")

View File

@@ -57,11 +57,11 @@ class YopoTrainer:
self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True) self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=learning_rate, fused=True)
print("Network Loaded! Loading Dataset...") print("Network Loaded! Loading Dataset...")
# dataset # dataset (you can adjust num_workers according to your training speed)
self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True, self.train_dataloader = DataLoader(YOPODataset(mode='train'), batch_size=self.batch_size, shuffle=True,
num_workers=1, pin_memory=True) num_workers=4, pin_memory=True)
self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False, self.val_dataloader = DataLoader(YOPODataset(mode='valid'), batch_size=self.batch_size, shuffle=False,
num_workers=1, pin_memory=True) num_workers=4, pin_memory=True)
print("Dataset Loaded!") print("Dataset Loaded!")
def train(self, epoch, save_interval=None): def train(self, epoch, save_interval=None):