莫方教程网

专业程序员编程教程与实战案例分享

PyTorch深度学习框架基础——数据集与数据加载方法

在 PyTorch 中,数据集Dataset数据加载器DataLoader是高效处理数据的核心组件。


Dataset 类

  • 抽象类:Dataset是一个抽象类,不能直接实例化。我们需要定义自己的数据集类,继承Dataset类,并实现其中的方法。
  • 可索引:Dataset支持索引操作,可以通过索引获取数据集中的任意数据样本。
  • 数据预处理:在Dataset中,我们可以对数据进行预处理、增强或归一化等操作,为后续的模型训练做好准备。

通过 torch.utils.data.Dataset 抽象类定义数据集的基本结构。自定义数据集需要继承 Dataset 并实现两个关键方法:

  • __len__(): 返回数据集的总样本数。
  • __getitem__(index): 根据索引返回单个样本(数据和标签)。

示例:自定义数据集

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data          # 数据(如张量、图像路径等)
        self.labels = labels      # 标签
        self.transform = transform  # 数据预处理

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)  # 应用预处理
        return sample, label

# 使用示例
data = ...  # 数据(如 NumPy 数组或列表)
labels = ...  # 标签
dataset = CustomDataset(data, labels)

DataLoader 类


torch.utils.data.DataLoader 负责批量加载数据,支持自动批处理、多进程加速和数据打乱。

  • 批量加载:DataLoader可以按照指定的batch_size从Dataset中取出一组数据进行加载,方便进行小批量训练。
  • 多线程加载:DataLoader支持多线程数据加载,可以显著提高数据加载速度。
  • 数据混洗:DataLoader可以在每个epoch开始时对数据集进行混洗,有助于提高模型的泛化能力。

核心参数:

  • batch_size: 每批样本数(默认为1)。
  • shuffle: 是否在每个 epoch 开始时打乱数据(默认为 False)。
  • num_workers: 加载数据的子进程数(建议根据 CPU 核数设置)。
  • drop_last: 是否丢弃最后一个不完整的批次(当样本数不能被 batch_size 整除时)。

示例:

from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

# 遍历数据
for batch_data, batch_labels in dataloader:
    # 输入模型训练
    outputs = model(batch_data)
    loss = criterion(outputs, batch_labels)

内置数据集

PyTorch 的 torchvision 和 torchtext 库提供了常用数据集(如 MNIST、CIFAR10、ImageNet 等)。

示例:加载 MNIST

import torchvision
from torchvision import transforms

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),          # 转为张量
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 下载并加载数据集
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

数据预处理(Transforms)

torchvision.transforms 提供常用的图像预处理方法:

  • Resize(): 调整图像尺寸
  • RandomCrop(): 随机裁剪
  • ToTensor(): 转为 PyTorch 张量
  • Normalize(): 归一化

自定义 Transforms

# Lambda 转换示例
transform = transforms.Lambda(lambda x: x * 2)  # 自定义操作

多任务/多数据集加载

合并数据集

from torch.utils.data import ConcatDataset

combined_dataset = ConcatDataset([dataset1, dataset2])
dataloader = DataLoader(combined_dataset, batch_size=64)

多模态数据

自定义 Dataset 返回多个数据源:

class MultiModalDataset(Dataset):
    def __getitem__(self, idx):
        image = self.images[idx]
        text = self.texts[idx]
        return image, text, labels[idx]

参考文档

  • PyTorch Dataset Docs:https://pytorch.org/docs/stable/data.html
  • Torchvision Transforms:https://pytorch.org/vision/stable/transforms.html
控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言