8 – 数据集加载

在第6部分,模型训练章节,代码段6-1展示了数据集(Dataset)和数据加载器(DataLoader)的配合示例,本文详细拆解一下二者的工作机制,下面是代码段6-1,这里再贴出来,代码段8-1。

import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler

class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # 返回数据集中的第 idx 个元素
        print(f"Fetching item with index: {idx}")
        return self.data[idx]

# 示例数据集
data = [i for i in range(100)]
dataset = SimpleDataset(data)

# 初始化 DataLoader
batch_size = 16
sampler = RandomSampler(dataset)  # 使用随机采样器
# sampler = SequentialSampler(dataset)  # 使用顺序采样器
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=2,
    collate_fn=lambda x: x  # 自定义批处理函数(这里使用默认行为)
)
# 遍历 DataLoader
for batch_idx, batch in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}: {batch}")

Dataset

上述代码中,类SimpleDataset 继承了类Dataset,并重写了基类的__getitem__ __len__函数。代码逻辑很简单,__len__ 获取数据元素长度,__getitem__ 根据下标获取指定元素。

在 Python 中,__len__ 和 __getitem__ 是两个特殊方法(也称为魔术方法或双下划线方法),它们分别用于定义对象的长度和索引访问行为。这些方法通常在自定义容器类中使用,使得这些类的行为类似于内置的序列类型),比如只要某个类实现了__len____getitem__ 两个方法,这个类就能用在任何期待序列的地方。而某个类为了支持迭代,只需实现 __getitem__ 方法,没必要提供 __len__方法。

回过来,我们看下Dataset类的实现,代码段8-2:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")


    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])

  类Dataset定义了一些方法,其中 __getitem__ 是必须实现的,而 __len__ 和 __getitems_ 是可选实现的,前者返回数据集的大小,后者用于实现批量样本的加载。

  在Pytorch中,所有表示从键到数据样本映射的数据集都应该继承自Dataset 类。这意味着如果你有一个数据集,其中每个数据样本可以通过一个键(通常是索引)来访问,那么你应该创建一个继承自 Dataset 的子类。所有子类都应该重写 __getitem__ 方法,以便于支持根据给定的键(通常是索引)获取数据样本,同时建议实际工程中,也在子类中实现__len__方法,下文会说明__len__方法的影响。

  DataLoader

  构建了数据集之后,需要通过DataLoader 类进行数据加载。DataLoader 是 PyTorch 中非常重要的一个组件,它的作用是从 Dataset 中按批次(batch)取出数据,但它并不关心数据是怎么来的 —— 只要 Dataset.__getitem__() 能返回一个样本。代码段8-3:

class DataLoader(Generic[T_co]):
   """
   Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler or Iterable, optional): defines the strategy to draw
            samples from the dataset. Can be any ``Iterable`` with ``__len__``
            implemented. If specified, :attr:`shuffle` must not be specified.
        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
            returns a batch of indices at a time. Mutually exclusive with
            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
            and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (Callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s).  Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into device/CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
            ``None``, the default `multiprocessing context`_ of your operating system will
            be used. (default: ``None``)
        generator (torch.Generator, optional): If not ``None``, this RNG will be used
            by RandomSampler to generate random indexes and multiprocessing to generate
            ``base_seed`` for workers. (default: ``None``)
        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
            in advance by each worker. ``2`` means there will be a total of
            2 * num_workers batches prefetched across all workers. (default value depends
            on the set value for num_workers. If value of num_workers=0 default is ``None``.
            Otherwise, if value of ``num_workers > 0`` default is ``2``).
        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
            the worker processes after a dataset has been consumed once. This allows to
            maintain the workers `Dataset` instances alive. (default: ``False``)
        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
            ``True``.
    """

  正确配置 DataLoader 可以显著提升数据预处理的效率,进而加快整体训练速度。针对上面参数列表,有如下设置建议:

   batch_size

  • 建议:根据硬件资源选择合适的批量大小。更大的批量可以利用 GPU 并行计算的优势,但过大的批量可能会导致内存(特指显存)不足。

   shuffle

  • 建议:在训练过程中开启数据打乱 (shuffle=True),有助于模型更好地泛化。但在验证或测试阶段,通常不需要打乱数据。

  num_workers

  • 建议:设置为 CPU 核心数的一半或全部,可以最大化数据加载速度。注意,如果数据集很小或数据加载速度不是瓶颈,则增加 num_workers 可能不会带来性能提升,反而会增加内存消耗。

  collate_fn

  • 建议:当数据集中的样本需要特殊处理时,自定义 collate_fn 函数。例如,文本数据可能需要填充到相同的长度。

  pin_memory

  • 建议:如果使用 GPU 进行训练,设置 pin_memory=True 可以加速从 CPU 到 GPU 的数据传输,使用锁页内存可以提高数据复制的速度。

  drop_last

  • 建议:当批量大小不能整除数据集大小时,考虑是否丢弃最后一个不完整的批量。如果模型训练对每个批量的大小敏感,可以设置 drop_last=True

  timeout

  • 建议:除非遇到数据加载超时的问题,否则通常不需要调整此参数。如果确实需要,可以适当增加超时时间。

  worker_init_fn

  • 建议:当数据集依赖于随机种子初始化(如数据增强)时,可以通过 worker_init_fn 设置不同的随机种子,确保不同进程之间的随机性。除此之外该函数还可以完成初始化资源操作,例如打开文件描述符、连接数据库等。

  prefetch_factor

  • 建议:默认情况下,每个工作进程会提前加载 2 个批量的数据。如果发现数据加载成为瓶颈,可以尝试增加 prefetch_factor,但要避免过度增加导致内存溢出。

   persistent_workers

  • 建议:对于大型数据集或复杂的预处理流程,设置 persistent_workers=True 可以保持工作进程的活跃状态,减少每次重启的开销。

   pin_memory_device

  • 建议:如果使用多个 GPU,可以通过指定 pin_memory_device 来优化数据传输。例如,如果你有一个主 GPU 和一个辅助 GPU,可以将 pin_memory_device 设置为主 GPU 的设备名,以加速数据传输到其他 GPU。

  特别的,针对num_workers大于1的情况,DataLoader类给出了如下的数据处理流程:


    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
  • 主进程(Main Process)

主进程负责创建 DataLoader 实例,并启动数据加载流程。

主进程将索引放入 index_queue,这些索引由 Sampler 生成,指示哪些数据项需要被加载。

  • 索引队列(Index Queue)

index_queue 是一个队列,用于存储从 Sampler 生成的索引。

主进程将这些索引放入 index_queue,工作进程从中取出索引。

  • 工作进程(Worker Processes)

工作进程从 index_queue 中取出索引,并根据这些索引从数据集中加载数据。

加载的数据被放入 worker_result_queue,准备传递给主进程。

  • 工作结果队列(Worker Result Queue)

worker_result_queue 是一个队列,用于存储工作进程加载的数据。

工作进程将加载的数据放入 worker_result_queue,主进程中的 pin_memory_thread 从中取出数据。

  • 主进程的 Pin Memory 线程(Pin Memory Thread of Main Process)

如果 pin_memory=True,主进程会启动一个 pin_memory_thread。

  pin_memory_thread 从 worker_result_queue 中取出数据,并将数据转换为带有 Pin Memory 的张量,以便更快地传输到 GPU。

  • 数据队列(Data Queue)

data_queue 是一个队列,用于存储处理好的数据。

pin_memory_thread 将处理好的数据放入 data_queue,准备最终输出。

  • 数据输出(Data Output)

最终,数据从 data_queue 中取出,作为 DataLoader 的输出返回给调用端。

模型使用数据

DataLoader 加载的数据,最终会输入给模型,代码段8-4:

def train(dataloader, model, loss_fn , optimizer):
    model.train()
    for batch,(X,y) in enumerate(dataloader):
        X,y = X.to(device), y.to(device)

值得注意的是,代码段8-4 中第4行的输出内容,取决于Dataset 类中__getitem__函数的实现,具体的,代码段8-5。

import torch
class Dataset(torch.utils.data.Dataset):
    def __len__(self):
        return 10000

    def __getitem__(self, any):
        return torch.tensor([[1, 2], [3, 4]])
        #return torch.empty(0)

if __name__ == '__main__':
    dl = torch.utils.data.DataLoader(
        Dataset(),
        batch_size=128,
        num_workers=1
        #drop_last=True  # 显式丢弃最后一个不完整批次
        )

    it = iter(dl)

    for i, x in enumerate(it):
        print(i, x.shape)

第8行的return 决定了最终的数据组织形式。

另外,代码3-4行,意味着数据集总共有10000 个样本。代码13行,设置batch_size = 128,总样本数 / 批量大小 = 10000 / 128 ≈ 78.125。这意味着完整批次数为 78 个(每批 128 个样本),最后一个不完整批次包含 10000 – 78 * 128 = 16 个样本。如果设置 drop_last=True,则会丢弃最后一个不完整批次。

工程实践

在实际工程中,需要处理的数据集动辄TB甚至PB级别,需要考虑硬件资源限制、效率等问题。

流式加载

我们以代码段8-6为例,代码逻辑加载磁盘文件到内存,然后根据索引键访问数据。

class LargeTextDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_length=128):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        self.load_data()

    def load_data(self):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = self.tokenizer.encode(line.strip())[:self.max_length]
                if len(tokens) > 1:
                    self.data.append(tokens)
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.long)

代码段8-6可能存在的问题:

  1. self.data=[] 将所有数据加载到内存中,假设每条文本平均500 tokens,1亿条数据需要约 `100,000,000 * 500 * 4 bytes ≈ 200GB` 内存(以float32为例)

当内存不够时,可用的解决方案是使用小批量分批处理,或者使用流式方式处理。

代码段8-7 使用迭代器进行懒加载数据,模拟流式处理数据。

class StreamTextDataset(IterableDataset):
    def __init__(self, file_path, tokenizer, max_length = 128, shard_id =0, num_shards = 4):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shard_id = shard_id
        self.num_shards = num_shards
    
    def _line_generator(self):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                yield line.strip()

    def __iter__(self):
        generator = self._line_generator()
        for idx, line in enumerate(generator):
            tokens = self.tokenizer.encode(line.strip())[:self.max_length]
            if len(tokens)>=2:
                yield torch.tensor(tokens, dtype= torch.long)

代码段8-7,类 StreamTextDataset继承自IterableDataset,后者用于从数据源创建迭代式的数据集。

执行代码段8-6和8-7,后者内存资源占用几乎可以忽略。

  1. 代码段8-6 还存在性能问题,原因是只有一个主进程在操作数据文件,自然而言的可以想到使用多进程方式,见代码段8-8。
class StreamTextDataset(IterableDataset):
    def __init__(self, file_path, tokenizer, max_length = 128, shard_id =0, num_shards = 4):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shard_id = shard_id
        self.num_shards = num_shards
    
    def _line_generator(self):
        with open(self.file_path, 'r', encoding='utf-8') as f:
            for line in f:
                yield line.strip()

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0

        generator = self._line_generator()
        for idx, line in enumerate(generator):
            if idx % num_workers != worker_id:
                continue
            tokens = self.tokenizer.encode(line.strip())[:self.max_length]
            if len(tokens)>=2:
                yield torch.tensor(tokens, dtype= torch.long)

代码段8-8,使用多个进程来处理数据。这个代码还有要改进的地方,这多个进程会重复读取全量文件数据,改进的方法是使用内存映射,然后不同的进程读取不同的数据块,这里不再赘述。

分片预取

通过预取机制,可以在当前批次的数据正在被处理的同时,预先加载下一个批次的数据。这样,在开始处理下一个批次的数据时,所需的数据已经准备就绪,从而减少等待时间提高整体吞吐量。

关键参数矩阵
参数默认值作用域优化方向
num_workers1进程级CPU 并行度
prefetch_factor2每个 worker流水线深度
persistent_workersFALSE全局worker 生命周期管理
batch_size1数据级GPU 吞吐量
shuffleFALSE数据分布内存访问模式

以上参数协同优化策略参考,

优先级顺序:

  1. 先确定batch_size(基于显存限制和模型需求)。现在很多工作都采用了动态批次调整策略,即在训练过程中逐渐增加批次大小,最终达到百万级别。例如,GPT-3 的批次大小从 32K 个词元逐渐增加到 3.2M个词元,这是因为较小的批次对应反向传播的频率更高,训练早期可以使用少量的数据让模型的损失尽快下降;而较大的批次可以在后期让模型的损失下降地更加稳定,使模型更好地收敛。
  2. 再调整num_workers(基于CPU资源和数据复杂度)。
  3. 最后优化prefetch_factor(平衡内存和加载速度)。

硬件适配:

  1. 高CPU低GPU:增大num_workers和prefetch_factor,减少GPU空闲。
  2. 高GPU低CPU:适当降低num_workers,避免CPU成为瓶颈。

辅助参数:

启用pin_memory=True加速数据从CPU到GPU的传输。

使用persistent_workers=True避免重复创建进程的开销(适用于多epoch训练)。

使用GPU加速

DALI英伟达推出的开源GPU加速数据加载与预处理库,通过将数据预处理任务从CPU卸载到GPU,显著提升训练和推理效率,主要针对图像、音频和视频数据。