在 PyTorch 中,数据集Dataset和数据加载器DataLoader是高效处理数据的核心组件。
Dataset 类
- 抽象类:Dataset是一个抽象类,不能直接实例化。我们需要定义自己的数据集类,继承Dataset类,并实现其中的方法。
- 可索引:Dataset支持索引操作,可以通过索引获取数据集中的任意数据样本。
- 数据预处理:在Dataset中,我们可以对数据进行预处理、增强或归一化等操作,为后续的模型训练做好准备。
通过 torch.utils.data.Dataset 抽象类定义数据集的基本结构。自定义数据集需要继承 Dataset 并实现两个关键方法:
- __len__(): 返回数据集的总样本数。
- __getitem__(index): 根据索引返回单个样本(数据和标签)。
示例:自定义数据集
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels, transform=None):
self.data = data # 数据(如张量、图像路径等)
self.labels = labels # 标签
self.transform = transform # 数据预处理
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
if self.transform:
sample = self.transform(sample) # 应用预处理
return sample, label
# 使用示例
data = ... # 数据(如 NumPy 数组或列表)
labels = ... # 标签
dataset = CustomDataset(data, labels)
DataLoader 类
torch.utils.data.DataLoader 负责批量加载数据,支持自动批处理、多进程加速和数据打乱。
- 批量加载:DataLoader可以按照指定的batch_size从Dataset中取出一组数据进行加载,方便进行小批量训练。
- 多线程加载:DataLoader支持多线程数据加载,可以显著提高数据加载速度。
- 数据混洗:DataLoader可以在每个epoch开始时对数据集进行混洗,有助于提高模型的泛化能力。
核心参数:
- batch_size: 每批样本数(默认为1)。
- shuffle: 是否在每个 epoch 开始时打乱数据(默认为 False)。
- num_workers: 加载数据的子进程数(建议根据 CPU 核数设置)。
- drop_last: 是否丢弃最后一个不完整的批次(当样本数不能被 batch_size 整除时)。
示例:
from torch.utils.data import DataLoader
dataloader = DataLoader(
dataset,
batch_size=32,
shuffle=True,
num_workers=4
)
# 遍历数据
for batch_data, batch_labels in dataloader:
# 输入模型训练
outputs = model(batch_data)
loss = criterion(outputs, batch_labels)
内置数据集
PyTorch 的 torchvision 和 torchtext 库提供了常用数据集(如 MNIST、CIFAR10、ImageNet 等)。
示例:加载 MNIST
import torchvision
from torchvision import transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(), # 转为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化
])
# 下载并加载数据集
train_dataset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = torchvision.datasets.MNIST(
root='./data',
train=False,
download=True,
transform=transform
)
数据预处理(Transforms)
torchvision.transforms 提供常用的图像预处理方法:
- Resize(): 调整图像尺寸
- RandomCrop(): 随机裁剪
- ToTensor(): 转为 PyTorch 张量
- Normalize(): 归一化
自定义 Transforms
# Lambda 转换示例
transform = transforms.Lambda(lambda x: x * 2) # 自定义操作
多任务/多数据集加载
合并数据集
from torch.utils.data import ConcatDataset
combined_dataset = ConcatDataset([dataset1, dataset2])
dataloader = DataLoader(combined_dataset, batch_size=64)
多模态数据
自定义 Dataset 返回多个数据源:
class MultiModalDataset(Dataset):
def __getitem__(self, idx):
image = self.images[idx]
text = self.texts[idx]
return image, text, labels[idx]
参考文档
- PyTorch Dataset Docs:https://pytorch.org/docs/stable/data.html
- Torchvision Transforms:https://pytorch.org/vision/stable/transforms.html