网易首页 > 网易号 > 正文 申请入驻

PyTorch 分布式训练底层原理与 DDP 实战指南

0
分享至

深度学习模型参数量和训练数据集的爆炸式增长,以 Llama 3.1 为例:4050 亿参数、15.6 万亿 token 的训练量,如果仅靠单 GPU可能需要数百年才能跑完,或者根本无法加载模型。

并行计算(Parallelism)通过将训练任务分发到多个 GPU(单机多卡或多机多卡),并利用通信原语同步状态,能让训练过程变得可控且高效。



本文讲详细探讨Pytorch的数据并行(Data Parallelism)

扩展算力的两种路径

扩展训练规模无非两种方式:纵向扩展(Vertical Scaling)横向扩展(Horizontal Scaling)

纵向扩展:



简单粗暴地升级硬件。比如把 10GB 显存的显卡换成 30GB 的。这种方式不需要改动代码,原本跑不起来的模型现在能跑了或者可以调大 batch size加快训练速度。

横向扩展:



增加机器数量。比如增加 7 台同配置(10GB)的机器,通过网络互联每台机器可以挂载单卡或多卡。这种方式需要代码层面的适配,利用 PyTorch 的分布式模块进行通信。

数据并行 vs 模型并行

数据并行 (Data Parallelism):



当模型本身能塞进单张 GPU,但数据量太大时,我们可以将模型复制到所有 GPU 上,将数据切分(Split),每个 GPU 并行处理一部分数据,在反向传播时同步梯度。

模型并行 (Model Parallelism):



当模型大到单张 GPU 放不下时使用。将模型切分成不同部分,分配给不同 GPU。每个 GPU 负责计算前向/后向传播中的一部分层。

实际超大模型训练中,通常是两者的混合。

前置概念:梯度累积

在讲 DDP 之前,先回顾一个基础技巧:梯度累积(Gradient Accumulation)。PyTorch 的设计中,loss.backward() 计算出的梯度默认是累加的,而非覆盖。

import torch
# Let us define a tensor with requires_grad = True
x = torch.tensor(4.0, requires_grad=True)
# Creating a function y=x^2
y = x*x
# Calculating dy/dx
y.backward(retain_graph=True)
# retain_graph = True because pytorch automatically removes the computation
# graph and intermediate tensors once backward is called to save memory
# If we want to call backward again, we need to tell pytorch not to delete
# the computation graph and intermediate tensors
print(f"Gradient of y w.r.t x after the first backward: {x.grad}")
# Output: 8.0 as dy/dx = 2*x = 2*4
# Now let us try to call backward again
y.backward()
print(f"Gradient of y w.r.t x after the second backward: {x.grad}")
# Output = 16 because 8 from the previous backward is added up here

利用这个特性,当大 Batch Size 导致 OOM 时,可以将其切分为多个 Mini-batch,连续执行多次 backward() 累积梯度,最后执行一次 optimizer.step()。这是单卡解决显存瓶颈的常用手段。

分布式数据并行 (DDP) 工作流

PyTorch 的 DistributedDataParallel (DDP) 是实现数据并行的核心模块,基于 c10d 的ProcessGroup进行通信,每个进程(Process)通常控制一个 GPU 并持有一个模型副本。

DDP 的标准执行流程如下:

  1. 初始化 ProcessGroup:建立进程间的通信握手。
  2. 广播权重(Broadcast):训练开始时,Rank 0 节点的模型权重被广播到所有其他节点,确保初始状态一致。
  3. 前向反向传播:每个节点在自己的数据子集上独立计算。
  4. 梯度归约(All-Reduce):各节点的梯度被汇聚(求和或平均),然后同步回所有节点。
  5. 参数更新:各节点使用同步后的梯度独立更新权重。



集合通信原语 (Collective Communication Primitives)

分布式训练中,点对点(Point-to-Point)通信效率太低。假设要把 5MB 参数发给 5 个节点,逐个发送会导致耗时线性增长。集合通信(Collective Communication)利用拓扑结构(如树状、环状)并行传输,是高性能计算的基石。



DDP 主要依赖两个原语:

  • Broadcast: 将数据从一个节点(通常是 Rank 0)分发给其余所有节点,用于初始化权重。
  • Reduce / All-Reduce: 将所有节点的数据汇总,DDP 中用于梯度同步。

故障恢复 (Failovers) 与 Checkpointing

在分布式环境中,节点故障是常态,一旦某个 Rank 挂了,重启整个集群从零训练成本过高。必须使用Checkpointing:定期将模型权重写入共享存储(Shared Disk)。



这样恢复训练时,可以从最新的 Checkpoint 加载权重。这里需要注意的是只允许 Rank 0 写入 Checkpoint,避免多进程同时写文件造成损坏。

代码实战:从 CPU 到 GPU

下面通过代码演示 DDP 的完整流程。先以 CPU 模拟多进程环境,再迁移到 GPU。

基础组件:Dataset 与 Mode

import torch
import torch.nn as nn
from torch.utils.data import Dataset
class SimpleDataset(Dataset):
def __init__(self, size=100):
self.size = size
self.data = torch.randn(size, 10) # generate 100 samples each having dimension 10
self.labels = torch.randn(size, 1) # generate labels corresponding to each sample
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
return self.fc2(x)

初始化环境

setup 函数负责建立进程组。

import os
import torch.distributed as dist
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost' # IP of the "master" node
os.environ['MASTER_PORT'] = '12355' # Port of communication
# If we have 4 processes, each process independently calls setup() with
# its own rank
dist.init_process_group(backend='gloo', rank=rank, world_size=world_size)
# 'gloo' is the collective communication backend used for CPU
# nccl is used for CUDA
print(f"\n{'='*60}")
print(f"[Rank {rank}] Process initialized!")
print(f"[Rank {rank}] Backend: {dist.get_backend()}")
print(f"[Rank {rank}] World Size: {dist.get_world_size()}")
print(f"{'='*60}\n")

数据切分:DistributedSampler

这是数据并行的关键。DistributedSampler 会根据 Rank ID 自动切分数据集,确保每个进程拿到不重叠的数据子集。

注意:必须在每个 epoch 开始前调用 set_epoch(epoch),否则每个 epoch 的数据切分顺序将完全一样导致模型只见过部分数据,影响泛化能力。

# Example usage conceptual
# Create DistributedSampler for each rank
# sampler_rank0 = DistributedSampler(dataset, num_replicas=4, rank=0)
# ...
# Loop
for epoch in range(num_epochs):
train_sampler.set_epoch(epoch) # Different shuffle each epoch
for batch in train_loader:
# Training code

核心训练 Loop (Worker)

from torch.nn.parallel import DistributedDataParallel as DDP
def print_separator(rank, message):
print(f"\n[Rank {rank}] {'-'*40}")
print(f"[Rank {rank}] {message}")
print(f"[Rank {rank}] {'-'*40}")
def train_worker(rank, world_size, num_epochs=2, batch_size=8):
# setup the distributed environment
setup(rank, world_size)
model = SimpleModel()
# wrap the model with DDP
# This is where weights are synchronized across ranks
ddp_model = DDP(model)
dataset = SimpleDataset(size=32)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)
criterion = nn.MSELoss()
for epoch in range(num_epochs):
sampler.set_epoch(epoch) # Ensure different shuffling per epoch
epoch_loss = 0.0
for batch_idx, (data, target) in enumerate(dataloader):
optimizer.zero_grad()
output = ddp_model(data)
loss = criterion(output, target)
# Backward pass - THIS IS WHERE DDP MAGIC HAPPENS
loss.backward()
# Gradients are synchronized (All-Reduce) here automatically
optimizer.step()
epoch_loss += loss.item()
dist.destroy_process_group()
print(f"[Rank {rank}] Training completed and cleaned up!\n")

验证集不能像训练集那样随意。有两种处理策略:

Rank 0 独占:只在 Rank 0 上跑全量验证集。这个方法比较简单但会造成其他 GPU 等待所以效率低。

分布式验证:各 Rank 跑一部分最后聚合 Loss 和样本数,一般都会用这个方案。

def validate(model, val_loader, criterion, rank, epoch):
model.eval()
val_loss = 0.0
num_samples = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(val_loader):
output = model(data)
loss = criterion(output, target)
val_loss += loss.item() * data.size(0)
num_samples += data.size(0)
# Aggregate metrics across all processes
total_loss_tensor = torch.tensor([val_loss])
total_samples_tensor = torch.tensor([num_samples])
# Sum across all processes
dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
dist.all_reduce(total_samples_tensor, op=dist.ReduceOp.SUM)
global_avg_loss = total_loss_tensor.item() / total_samples_tensor.item()
return global_avg_loss

启动进程,CPU 演示通常用 mp.spawn:

def main():
world_size = 2
mp.spawn(
demo_worker,
args=(world_size,),
nprocs=world_size,
join=True
)

生产环境迁移:CUDA 与 Torchrun

在实际 GPU 训练中,需要做 5 点改动:

  1. Backend: gloo -> nccl (NVIDIA 专用,速度最快)。
  2. Model: model.cuda(rank)。
  3. DDP Wrapper: DDP(model, device_ids=[rank])。
  4. Data: data.cuda(rank)。
  5. Device: torch.cuda.set_device(rank)。

启动方式不再推荐使用mp.spawn,而是直接使用Torch自带的CLI工具torchrun,它能自动处理环境变量(RANK, WORLD_SIZE, LOCAL_RANK)并支持故障重启。

# Code expects env vars
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
demo_worker(rank, world_size)

运行命令:

torchrun --nproc_per_node=4 train.py

性能优化:Bucketing 与 Overlap

PyTorch DDP 之所以快,不仅仅是因为分了数据,更在于它对通信的优化。

通信与计算重叠 (Communication Overlap)

我们可能认为要等所有层的梯度算完再同步,但这会导致 GPU 出现长时间空闲(Wait)。所以DDP 的做法是一旦某层的梯度算出来,如果不被依赖,就立刻发送同步



如上图,Layer 4 的梯度一算好,在计算 Layer 3 的同时,Rank 0 已经在处理 Layer 4 的同步了。

分桶 (Bucketing)

为了避免频繁发送小包导致网络拥塞,DDP 会将梯度打包进 Bucket(默认 25MB)。

当一个 Bucket 被填满(例如包含 Layer 4, 5, 6 的梯度),就会触发一次 All-Reduce。这种批量处理大幅降低了通信开销。



这是一个为您准备的结尾总结,保持了之前设定的专业且行动导向的风格,同时也呼应了原作者关于“下一篇讲解模型并行”的预告:

总结

我们已经拆解了 PyTorch DDP 的核心运作机制:从底层的 ProcessGroup 通信握手,到梯度的 All-Reduce 同步,再到 Bucket 分桶与计算通信重叠的性能优化。掌握这些,你已经具备了将单卡代码低成本迁移到多卡集群的能力,不再受限于单机的训练时长。

https://avoid.overfit.cn/post/11d9f5d9b4fc4cd49cf1b8f97f09252f

作者:Trinanjan Mitra

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
韩寒的这场“豪赌”,赢得很彻底,他能分账多少钱?

韩寒的这场“豪赌”,赢得很彻底,他能分账多少钱?

八卦南风
2026-02-25 13:37:35
香港自杀女警手机中发现遗书

香港自杀女警手机中发现遗书

现代快报
2026-02-25 20:33:04
iPhone 18 Pro灵动岛缩小35%,将首发搭载基于台积电2nm工艺制造的A20 Pro芯片

iPhone 18 Pro灵动岛缩小35%,将首发搭载基于台积电2nm工艺制造的A20 Pro芯片

中国能源网
2026-02-24 12:01:06
德国总理没想到,落地中国才一天,访华前一个举动让自己口碑暴增

德国总理没想到,落地中国才一天,访华前一个举动让自己口碑暴增

郭夷包工头
2026-02-25 19:03:48
《太平年》在越南吵翻:同一段乱世,吴越选择纳土,越南选择独立

《太平年》在越南吵翻:同一段乱世,吴越选择纳土,越南选择独立

狐狸先森讲升学规划
2026-02-25 10:10:42
东契奇谈最后拒投三分:我知道自己是空位,但觉得距离远了点

东契奇谈最后拒投三分:我知道自己是空位,但觉得距离远了点

懂球帝
2026-02-25 15:12:18
升级版的仙人跳,比戴绿帽子还憋屈

升级版的仙人跳,比戴绿帽子还憋屈

霹雳炮
2026-02-24 22:53:34
金字塔碳14检测后,真相混乱到让学界集体“疯魔”

金字塔碳14检测后,真相混乱到让学界集体“疯魔”

混沌录
2026-02-25 22:30:19
世界第一女巨人来自中国安徽,穿78码的鞋子,一顿饭吃六碗炒面

世界第一女巨人来自中国安徽,穿78码的鞋子,一顿饭吃六碗炒面

不写散文诗
2026-02-25 21:02:18
美媒解读东契奇拒投绝杀后唇语:詹姆斯让我传球,我就传了

美媒解读东契奇拒投绝杀后唇语:詹姆斯让我传球,我就传了

懂球帝
2026-02-26 09:45:10
男演员长相多重要?把34岁黄景瑜和25岁陈飞宇对比,差距一目了然

男演员长相多重要?把34岁黄景瑜和25岁陈飞宇对比,差距一目了然

银河史记
2026-02-25 22:30:03
挪威前首相自杀未遂,命悬一线!其被指涉爱泼斯坦案,多处住所被突袭搜查!欧洲多国政要、王室成员被查

挪威前首相自杀未遂,命悬一线!其被指涉爱泼斯坦案,多处住所被突袭搜查!欧洲多国政要、王室成员被查

每日经济新闻
2026-02-25 17:15:06
痛心!广东英德1岁走失男童在报警人家附近鱼塘中被发现,已无生命体征

痛心!广东英德1岁走失男童在报警人家附近鱼塘中被发现,已无生命体征

封面新闻
2026-02-26 01:57:06
上海这晚,57岁周涛秒了30岁李雪琴,不愧是央视严选的国泰民安脸

上海这晚,57岁周涛秒了30岁李雪琴,不愧是央视严选的国泰民安脸

大铁猫娱乐
2026-02-08 00:10:03
封神!株洲司机最后1秒冲过收费站,收费员比他还疯,全网笑炸

封神!株洲司机最后1秒冲过收费站,收费员比他还疯,全网笑炸

观察鉴娱
2026-02-25 10:09:22
西湖大学打了谁的脸?外籍学生学费35万一年,国内学生仅6千元

西湖大学打了谁的脸?外籍学生学费35万一年,国内学生仅6千元

妍妍教育日记
2026-02-24 18:35:18
孩子第一天就轰动学校是啥感觉?网友:这孩子以后能成大事

孩子第一天就轰动学校是啥感觉?网友:这孩子以后能成大事

解读热点事件
2026-02-25 15:32:21
卡塞米罗表现下滑,卡里克或重用曼联新博格巴!小将希文改打后腰

卡塞米罗表现下滑,卡里克或重用曼联新博格巴!小将希文改打后腰

体坛关键帧
2026-02-26 09:51:39
33岁重庆女子命丧中灵山,遗体挂悬崖,目击者发声 丈夫行为引争议

33岁重庆女子命丧中灵山,遗体挂悬崖,目击者发声 丈夫行为引争议

小鹿姐姐情感说
2026-02-25 19:37:04
宋彬彬晚年回国道歉仍不被原谅,其父宋任穷也不愿提起她,为何

宋彬彬晚年回国道歉仍不被原谅,其父宋任穷也不愿提起她,为何

春秋砚
2026-02-24 12:25:08
2026-02-26 11:28:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1931文章数 1456关注度
往期回顾 全部

科技要闻

单季营收681亿净利429亿!英伟达再次炸裂

头条要闻

"花坛白骨案"2名凶手因4万元杀人埋尸 受害人儿子发声

头条要闻

"花坛白骨案"2名凶手因4万元杀人埋尸 受害人儿子发声

体育要闻

从排球少女到冰壶女神,她在米兰冬奥练出6块腹肌

娱乐要闻

尼格买提撒贝宁滑雪被偶遇 17年老友情

财经要闻

短剧市场风云突变!有人投百万赔得精光

汽车要闻

雷克萨斯ES双色特别版上市 售30.79万元起

态度原创

时尚
健康
艺术
房产
军事航空

伦敦时装周|2026秋冬流行趋势早知道

转头就晕的耳石症,能开车上班吗?

艺术要闻

谁能认出这幅14字草书的真正作者?

房产要闻

海南楼市春节热销地图曝光!三亚、陵水又杀疯了!

军事要闻

美政府给新伊核协议设限内容遭披露

无障碍浏览 进入关怀版