在第6部分,模型训练章节,代码段6-1展示了数据集(Dataset)和数据加载器(DataLoader)的配合示例,本文详细拆解一下二者的工作机制,下面是代码段6-1,这里再贴出来,代码段8-1。
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler
class SimpleDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# 返回数据集中的第 idx 个元素
print(f"Fetching item with index: {idx}")
return self.data[idx]
# 示例数据集
data = [i for i in range(100)]
dataset = SimpleDataset(data)
# 初始化 DataLoader
batch_size = 16
sampler = RandomSampler(dataset) # 使用随机采样器
# sampler = SequentialSampler(dataset) # 使用顺序采样器
data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=2,
collate_fn=lambda x: x # 自定义批处理函数(这里使用默认行为)
)
# 遍历 DataLoader
for batch_idx, batch in enumerate(data_loader):
print(f"Batch {batch_idx + 1}: {batch}")
Dataset
上述代码中,类SimpleDataset 继承了类Dataset,并重写了基类的__getitem__ 和 __len__函数。代码逻辑很简单,__len__ 获取数据元素长度,__getitem__ 根据下标获取指定元素。
在 Python 中,__len__ 和 __getitem__ 是两个特殊方法(也称为魔术方法或双下划线方法),它们分别用于定义对象的长度和索引访问行为。这些方法通常在自定义容器类中使用,使得这些类的行为类似于内置的序列类型),比如只要某个类实现了__len__和 __getitem__ 两个方法,这个类就能用在任何期待序列的地方。而某个类为了支持迭代,只需实现 __getitem__ 方法,没必要提供 __len__方法。
回过来,我们看下Dataset类的实现,代码段8-2:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`. Subclasses could also
optionally implement :meth:`__getitems__`, for speedup batched samples
loading. This method accepts list of indices of samples of batch and returns
list of samples.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs an index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")
def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
return ConcatDataset([self, other])
类Dataset定义了一些方法,其中 __getitem__ 是必须实现的,而 __len__ 和 __getitems_ 是可选实现的,前者返回数据集的大小,后者用于实现批量样本的加载。
在Pytorch中,所有表示从键到数据样本映射的数据集都应该继承自Dataset 类。这意味着如果你有一个数据集,其中每个数据样本可以通过一个键(通常是索引)来访问,那么你应该创建一个继承自 Dataset 的子类。所有子类都应该重写 __getitem__ 方法,以便于支持根据给定的键(通常是索引)获取数据样本,同时建议实际工程中,也在子类中实现__len__方法,下文会说明__len__方法的影响。
DataLoader
构建了数据集之后,需要通过DataLoader
类进行数据加载。DataLoader
是 PyTorch 中非常重要的一个组件,它的作用是从 Dataset
中按批次(batch)取出数据,但它并不关心数据是怎么来的 —— 只要 Dataset.__getitem__()
能返回一个样本。代码段8-3:
class DataLoader(Generic[T_co]):
"""
Args:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``).
sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
and :attr:`drop_last`.
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
collate_fn (Callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
into device/CUDA pinned memory before returning them. If your data elements
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
see the example below.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``)
timeout (numeric, optional): if positive, the timeout value for collecting a batch
from workers. Should always be non-negative. (default: ``0``)
worker_init_fn (Callable, optional): If not ``None``, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: ``None``)
multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
``None``, the default `multiprocessing context`_ of your operating system will
be used. (default: ``None``)
generator (torch.Generator, optional): If not ``None``, this RNG will be used
by RandomSampler to generate random indexes and multiprocessing to generate
``base_seed`` for workers. (default: ``None``)
prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
in advance by each worker. ``2`` means there will be a total of
2 * num_workers batches prefetched across all workers. (default value depends
on the set value for num_workers. If value of num_workers=0 default is ``None``.
Otherwise, if value of ``num_workers > 0`` default is ``2``).
persistent_workers (bool, optional): If ``True``, the data loader will not shut down
the worker processes after a dataset has been consumed once. This allows to
maintain the workers `Dataset` instances alive. (default: ``False``)
pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
``True``.
"""
正确配置 DataLoader 可以显著提升数据预处理的效率,进而加快整体训练速度。针对上面参数列表,有如下设置建议:
batch_size
- 建议:根据硬件资源选择合适的批量大小。更大的批量可以利用 GPU 并行计算的优势,但过大的批量可能会导致内存(特指显存)不足。
shuffle
- 建议:在训练过程中开启数据打乱 (
shuffle=True
),有助于模型更好地泛化。但在验证或测试阶段,通常不需要打乱数据。
num_workers
- 建议:设置为 CPU 核心数的一半或全部,可以最大化数据加载速度。注意,如果数据集很小或数据加载速度不是瓶颈,则增加
num_workers
可能不会带来性能提升,反而会增加内存消耗。
collate_fn
- 建议:当数据集中的样本需要特殊处理时,自定义
collate_fn
函数。例如,文本数据可能需要填充到相同的长度。
pin_memory
- 建议:如果使用 GPU 进行训练,设置
pin_memory=True
可以加速从 CPU 到 GPU 的数据传输,使用锁页内存可以提高数据复制的速度。
drop_last
- 建议:当批量大小不能整除数据集大小时,考虑是否丢弃最后一个不完整的批量。如果模型训练对每个批量的大小敏感,可以设置
drop_last=True
。
timeout
- 建议:除非遇到数据加载超时的问题,否则通常不需要调整此参数。如果确实需要,可以适当增加超时时间。
worker_init_fn
- 建议:当数据集依赖于随机种子初始化(如数据增强)时,可以通过
worker_init_fn
设置不同的随机种子,确保不同进程之间的随机性。除此之外该函数还可以完成初始化资源操作,例如打开文件描述符、连接数据库等。
prefetch_factor
- 建议:默认情况下,每个工作进程会提前加载 2 个批量的数据。如果发现数据加载成为瓶颈,可以尝试增加
prefetch_factor
,但要避免过度增加导致内存溢出。
persistent_workers
- 建议:对于大型数据集或复杂的预处理流程,设置
persistent_workers=True
可以保持工作进程的活跃状态,减少每次重启的开销。
pin_memory_device
- 建议:如果使用多个 GPU,可以通过指定
pin_memory_device
来优化数据传输。例如,如果你有一个主 GPU 和一个辅助 GPU,可以将pin_memory_device
设置为主 GPU 的设备名,以加速数据传输到其他 GPU。
特别的,针对num_workers大于1的情况,DataLoader类给出了如下的数据处理流程:
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
- 主进程(Main Process)
主进程负责创建 DataLoader 实例,并启动数据加载流程。
主进程将索引放入 index_queue,这些索引由 Sampler 生成,指示哪些数据项需要被加载。
- 索引队列(Index Queue)
index_queue 是一个队列,用于存储从 Sampler 生成的索引。
主进程将这些索引放入 index_queue,工作进程从中取出索引。
- 工作进程(Worker Processes)
工作进程从 index_queue 中取出索引,并根据这些索引从数据集中加载数据。
加载的数据被放入 worker_result_queue,准备传递给主进程。
- 工作结果队列(Worker Result Queue)
worker_result_queue 是一个队列,用于存储工作进程加载的数据。
工作进程将加载的数据放入 worker_result_queue,主进程中的 pin_memory_thread 从中取出数据。
- 主进程的 Pin Memory 线程(Pin Memory Thread of Main Process)
如果 pin_memory=True,主进程会启动一个 pin_memory_thread。
pin_memory_thread 从 worker_result_queue 中取出数据,并将数据转换为带有 Pin Memory 的张量,以便更快地传输到 GPU。
- 数据队列(Data Queue)
data_queue 是一个队列,用于存储处理好的数据。
pin_memory_thread 将处理好的数据放入 data_queue,准备最终输出。
- 数据输出(Data Output)
最终,数据从 data_queue 中取出,作为 DataLoader 的输出返回给调用端。
模型使用数据
DataLoader 加载的数据,最终会输入给模型,代码段8-4:
def train(dataloader, model, loss_fn , optimizer):
model.train()
for batch,(X,y) in enumerate(dataloader):
X,y = X.to(device), y.to(device)
值得注意的是,代码段8-4 中第4行的输出内容,取决于Dataset 类中__getitem__函数的实现,具体的,代码段8-5。
import torch
class Dataset(torch.utils.data.Dataset):
def __len__(self):
return 10000
def __getitem__(self, any):
return torch.tensor([[1, 2], [3, 4]])
#return torch.empty(0)
if __name__ == '__main__':
dl = torch.utils.data.DataLoader(
Dataset(),
batch_size=128,
num_workers=1
#drop_last=True # 显式丢弃最后一个不完整批次
)
it = iter(dl)
for i, x in enumerate(it):
print(i, x.shape)
第8行的return 决定了最终的数据组织形式。
另外,代码3-4行,意味着数据集总共有10000 个样本。代码13行,设置batch_size = 128,总样本数 / 批量大小 = 10000 / 128 ≈ 78.125。这意味着完整批次数为 78 个(每批 128 个样本),最后一个不完整批次包含 10000 – 78 * 128 = 16 个样本。如果设置 drop_last=True,则会丢弃最后一个不完整批次。
工程实践
在实际工程中,需要处理的数据集动辄TB甚至PB级别,需要考虑硬件资源限制、效率等问题。
流式加载
我们以代码段8-6为例,代码逻辑加载磁盘文件到内存,然后根据索引键访问数据。
class LargeTextDataset(Dataset):
def __init__(self, file_path, tokenizer, max_length=128):
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
self.data = []
self.load_data()
def load_data(self):
with open(self.file_path, 'r', encoding='utf-8') as f:
for line in f:
tokens = self.tokenizer.encode(line.strip())[:self.max_length]
if len(tokens) > 1:
self.data.append(tokens)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return torch.tensor(self.data[idx], dtype=torch.long)
代码段8-6可能存在的问题:
- self.data=[] 将所有数据加载到内存中,假设每条文本平均500 tokens,1亿条数据需要约 `100,000,000 * 500 * 4 bytes ≈ 200GB` 内存(以float32为例)
当内存不够时,可用的解决方案是使用小批量分批处理,或者使用流式方式处理。
代码段8-7 使用迭代器进行懒加载数据,模拟流式处理数据。
class StreamTextDataset(IterableDataset):
def __init__(self, file_path, tokenizer, max_length = 128, shard_id =0, num_shards = 4):
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
self.shard_id = shard_id
self.num_shards = num_shards
def _line_generator(self):
with open(self.file_path, 'r', encoding='utf-8') as f:
for line in f:
yield line.strip()
def __iter__(self):
generator = self._line_generator()
for idx, line in enumerate(generator):
tokens = self.tokenizer.encode(line.strip())[:self.max_length]
if len(tokens)>=2:
yield torch.tensor(tokens, dtype= torch.long)
代码段8-7,类 StreamTextDataset继承自IterableDataset,后者用于从数据源创建迭代式的数据集。
执行代码段8-6和8-7,后者内存资源占用几乎可以忽略。
- 代码段8-6 还存在性能问题,原因是只有一个主进程在操作数据文件,自然而言的可以想到使用多进程方式,见代码段8-8。
class StreamTextDataset(IterableDataset):
def __init__(self, file_path, tokenizer, max_length = 128, shard_id =0, num_shards = 4):
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
self.shard_id = shard_id
self.num_shards = num_shards
def _line_generator(self):
with open(self.file_path, 'r', encoding='utf-8') as f:
for line in f:
yield line.strip()
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
num_workers = worker_info.num_workers if worker_info else 1
worker_id = worker_info.id if worker_info else 0
generator = self._line_generator()
for idx, line in enumerate(generator):
if idx % num_workers != worker_id:
continue
tokens = self.tokenizer.encode(line.strip())[:self.max_length]
if len(tokens)>=2:
yield torch.tensor(tokens, dtype= torch.long)
代码段8-8,使用多个进程来处理数据。这个代码还有要改进的地方,这多个进程会重复读取全量文件数据,改进的方法是使用内存映射,然后不同的进程读取不同的数据块,这里不再赘述。
分片预取
通过预取机制,可以在当前批次的数据正在被处理的同时,预先加载下一个批次的数据。这样,在开始处理下一个批次的数据时,所需的数据已经准备就绪,从而减少等待时间提高整体吞吐量。
关键参数矩阵
参数 | 默认值 | 作用域 | 优化方向 |
num_workers | 1 | 进程级 | CPU 并行度 |
prefetch_factor | 2 | 每个 worker | 流水线深度 |
persistent_workers | FALSE | 全局 | worker 生命周期管理 |
batch_size | 1 | 数据级 | GPU 吞吐量 |
shuffle | FALSE | 数据分布 | 内存访问模式 |
以上参数协同优化策略参考,
优先级顺序:
- 先确定batch_size(基于显存限制和模型需求)。现在很多工作都采用了动态批次调整策略,即在训练过程中逐渐增加批次大小,最终达到百万级别。例如,GPT-3 的批次大小从 32K 个词元逐渐增加到 3.2M个词元,这是因为较小的批次对应反向传播的频率更高,训练早期可以使用少量的数据让模型的损失尽快下降;而较大的批次可以在后期让模型的损失下降地更加稳定,使模型更好地收敛。
- 再调整num_workers(基于CPU资源和数据复杂度)。
- 最后优化prefetch_factor(平衡内存和加载速度)。
硬件适配:
- 高CPU低GPU:增大num_workers和prefetch_factor,减少GPU空闲。
- 高GPU低CPU:适当降低num_workers,避免CPU成为瓶颈。
辅助参数:
启用pin_memory=True加速数据从CPU到GPU的传输。
使用persistent_workers=True避免重复创建进程的开销(适用于多epoch训练)。
使用GPU加速
DALI英伟达推出的开源GPU加速数据加载与预处理库,通过将数据预处理任务从CPU卸载到GPU,显著提升训练和推理效率,主要针对图像、音频和视频数据。