网易首页 > 网易号 > 正文 申请入驻

Mosaic:面向超长序列的多GPU注意力分片方案

0
分享至

Transformer的"二次方注意力瓶颈"的问题是老生常谈了。这个瓶颈到底卡在哪实际工程里怎么绕过去?本文从一个具体问题出发,介绍Mosaic这套多轴注意力分片方案的设计思路。



注意力的内存困境

注意力机制的计算公式:

Attention(Q, K, V) = softmax(QKᵀ / √d) × V

问题出在QKᵀ这个矩阵上,它的形状是 (序列长度 × 序列长度)。

拿150,000个token的序列算一下:

Memory = 150,000² × 4 bytes = 90 billion bytes ≈ 84 GB

这只是注意力权重本身的开销,而且还是单层、单头。A100的显存上限是80GB,放不下就是放不下。

现有方案的局限

FlashAttention它通过分块计算,不需要把完整的注意力矩阵实例化出来,内存复杂度从O(n²)降到O(n)。单卡场景下效果很好,但问题是整个序列还是得塞进同一张GPU。

Ring Attention换了个思路:把序列切片分到多张GPU上,每张卡持有一部分Q,K和V在GPU之间像传令牌一样轮转,一维序列处理起来是很不错的。

但是多维怎么办?

比如处理表格数据的Transformer,输入张量形状是 (batch, rows, features, embed)。模型需要在不同维度上做注意力:features维度只有5个token,rows维度却有150,000个。前者单卡轻松搞定,后者则必须分片。

现有的库都没法干净地处理这种多轴场景。手写的话,每个轴要单独写分片逻辑,进程组管理、张量reshape全得自己来。代码会变得很脏。

Mosaic的设计

Mosaic本质上是个协调层,负责把不同的注意力轴路由到合适的计算后端:

import mosaic
# Small axis: run locally
feature_attn = mosaic.MultiAxisAttention(
embed_dim=96,
num_heads=4,
attention_axis=2, # features dimension
backend="local" # no communication needed
)
# Large axis: shard across GPUs
row_attn = mosaic.MultiAxisAttention(
embed_dim=96,
num_heads=4,
attention_axis=1, # rows dimension
backend="ring" # ring attention across GPUs
)

底层Mosaic会自动处理轴的置换、QKV投影前的reshape、后端分发、以及计算完成后张量形状的还原。模型代码保持清晰,分布式的复杂性被封装掉了。

Ring Attention的工作机制

核心思想其实很直接:不需要同时持有全部的K和V。可以分批计算注意力分数,逐步累积,最后再做归一化。

比如说4张GPU的情况下流程是这样的:

Initial state:
GPU 0: Q₀, K₀, V₀
GPU 1: Q₁, K₁, V₁
GPU 2: Q₂, K₂, V₂
GPU 3: Q₃, K₃, V₃
Step 1: Each GPU computes attention with its local K, V
GPU 0: score₀₀ = Q₀ @ K₀ᵀ
...
Step 2: Pass K, V to the next GPU in the ring
GPU 0 receives K₃, V₃ from GPU 3
GPU 0 sends K₀, V₀ to GPU 1
Step 3: Compute attention with received K, V
GPU 0: score₀₃ = Q₀ @ K₃ᵀ
Accumulate with score₀₀
Repeat for all chunks...
Final: Each GPU has complete attention output for its Q chunk

单卡内存占用变成O(n²/p),p是GPU数量。8张卡的话内存需求直接砍到1/8。150k序列从84GB降到约10GB每卡。

Mesh2D:更激进的分片

序列特别长的时候Ring Attention的线性分片可能还不够,这时候可以用Mesh2D把Q和K都切分了:

4 GPUs arranged in 2×2 mesh:
K₀ K₁
┌──────┬──────┐
Q₀ │GPU 0 │GPU 1 │
├──────┼──────┤
Q₁ │GPU 2 │GPU 3 │
└──────┴──────┘
Each GPU computes one tile of QKᵀ

内存复杂度降到O(n²/p²)。64张卡组成8×8网格时,每卡内存需求下降64倍。

attn = mosaic.MultiAxisAttention(
embed_dim=128,
num_heads=8,
attention_axis=1,
backend="mesh2d",
mesh_shape=(8, 8)
)

感知集群拓扑的组合策略

在实际部署环境里,不同GPU之间的通信带宽差异很大。节点内GPU走NVLink能到900 GB/s,跨节点通过InfiniBand通常只有200 GB/s左右。

ComposedAttention就是针对这种拓扑特征设计的:

# 4 nodes × 8 GPUs = 32 total
composed = mosaic.ComposedAttention(
mesh_shape=(4, 8), # (nodes, gpus_per_node)
head_parallel=True, # Split heads across nodes (slow link)
seq_parallel="ring" # Ring within nodes (fast link)
)

需要更精细控制的话,可以用 HierarchicalAttention:

hier = mosaic.HierarchicalAttention(
intra_node_size=8,
intra_node_strategy="local", # Compute locally within node
inter_node_strategy="ring" # Ring between node leaders
)

重通信走快链路轻通信才跨节点。



实现细节

整个库大约800行Python,核心代码如下:

class MultiAxisAttention(nn.Module):
def forward(self, x):
# 1. Move attention axis to seq position
x, inv_perm = self._permute_to_seq(x)
# 2. Flatten batch dims, project QKV
x = x.view(-1, seq_len, embed_dim)
qkv = self.qkv_proj(x).view(batch, seq, 3, heads, head_dim)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
# 3. Dispatch to backend
out = self._attn_fn(q, k, v) # local, ring, or mesh2d
# 4. Project output, restore shape
out = self.out_proj(out.transpose(1, 2).reshape(...))
return out.permute(inv_perm)

后端封装了现有的成熟实现:local后端调用F.scaled_dot_product_attention(也就是FlashAttention),ring后端用ring-flash-attn库的ring_flash_attn_func,mesh2d是自定义的all-gather加SDPA,所有的底层都跑的是FlashAttention内核。

所有后端统一用FlashAttention的融合GEMM+softmax实现。后端函数在初始化时就绑定好,前向传播不做分支判断。张量操作尽量用x.view()而不是x.reshape(),保持内存连续性。集合通信的目标张量预分配好,避免torch.cat的开销。模块级别做导入不在每次前向传播时产生import开销。

快速上手

安装:

pip install git+https://github.com/stprnvsh/mosaic.git
# With ring attention support
pip install flash-attn ring-flash-attn

单节点启动:

torchrun --nproc_per_node=4 train.py

多节点的话:

# Node 0
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 \
--master_addr=192.168.1.100 --master_port=29500 train.py
# Node 1
torchrun --nnodes=2 --nproc_per_node=8 --node_rank=1 \
--master_addr=192.168.1.100 --master_port=29500 train.py

训练脚本示例:

import mosaic
import torch.distributed as dist
dist.init_process_group("nccl")
ctx = mosaic.init(sp_size=dist.get_world_size())
model = MyModel().to(ctx.device)
# Data is pre-sharded: each GPU has seq_total / world_size tokens
x_local = load_my_shard()
out = model(x_local) # Communication handled by Mosaic

总结

最后,Mosaic不会自动并行化模型(这个用nnScaler),不管数据并行(PyTorch DDP/FSDP的事),也不处理模型分片(交给FSDP或Megatron)。

Mosaic专注于一件事:多轴注意力的分片路由,这套方案最初是给nanoTabPFN做的,一个表格数据Transformer。

这个模型要同时在rows(150k个)和features(5个)两个维度做注意力。标准Ring Attention对维度语义没有感知,它只认序列这个概念,分不清rows和features的区别。

所以Mosaic需求很明确:小轴本地算,大轴分布式算,轴的路由逻辑不能侵入模型代码,有兴趣的可以试试。

https://avoid.overfit.cn/post/791e0f30540e4d289a43d01d383e8ab2

作者:Pranav Sateesh

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
中国骨科,凛冬已逝

中国骨科,凛冬已逝

钛媒体APP
2026-01-16 16:08:09
我也不能怪父亲!男子回家过年没地方睡:这样的家庭,我该怎么办

我也不能怪父亲!男子回家过年没地方睡:这样的家庭,我该怎么办

唐小糖说情感
2026-01-16 16:54:54
被特朗普指责阻碍俄乌和平 泽连斯基反驳

被特朗普指责阻碍俄乌和平 泽连斯基反驳

看看新闻Knews
2026-01-16 23:21:02
罗马仕被曝正亏本清理库存充电宝:27000mAh型号售价45元

罗马仕被曝正亏本清理库存充电宝:27000mAh型号售价45元

IT之家
2026-01-16 11:00:18
贾国龙为什么维护“流泪劝谏”的华与华?

贾国龙为什么维护“流泪劝谏”的华与华?

界面新闻
2026-01-16 22:42:55
女人染上“性瘾”是一种怎样的体验?可能和你想象得不同

女人染上“性瘾”是一种怎样的体验?可能和你想象得不同

纸上的心语
2025-11-23 11:36:00
贾国龙果然是个大傻子

贾国龙果然是个大傻子

端木赐香三糊涂
2026-01-16 14:34:16
CBA积分榜最新排名出炉!辽篮掀翻新疆杀回第9:广东第3山西第10

CBA积分榜最新排名出炉!辽篮掀翻新疆杀回第9:广东第3山西第10

篮球快餐车
2026-01-17 05:28:06
狂降10℃!四川将迎来强降温

狂降10℃!四川将迎来强降温

鲁中晨报
2026-01-16 17:23:09
2026年1月存款利率大改!1万元存一年利息多少?我算透了说实在的

2026年1月存款利率大改!1万元存一年利息多少?我算透了说实在的

星辰宇的不羁
2026-01-11 10:01:20
王菲女儿窦靖童,这身打扮完全化身妈妈王菲了有种未来天后的感觉

王菲女儿窦靖童,这身打扮完全化身妈妈王菲了有种未来天后的感觉

手工制作阿歼
2026-01-17 06:20:51
离过年还有一个月,农村五大“怪现象”竟让人惊掉下巴!

离过年还有一个月,农村五大“怪现象”竟让人惊掉下巴!

特约前排观众
2026-01-17 00:15:05
爸爸去哪儿6个孩子现状:有人进国家队,有人出家,有人出国断联

爸爸去哪儿6个孩子现状:有人进国家队,有人出家,有人出国断联

小兔子的快乐
2026-01-15 22:35:50
何穗产后日常曝光:穿棉袄驱寒、咖啡不离手,自拍卖萌投喂陈伟霆

何穗产后日常曝光:穿棉袄驱寒、咖啡不离手,自拍卖萌投喂陈伟霆

无处不风景love
2026-01-15 22:14:10
随着比分定格2-3,阿联酋出局,亚洲杯4强诞生2席:日本和东南亚劲旅

随着比分定格2-3,阿联酋出局,亚洲杯4强诞生2席:日本和东南亚劲旅

侧身凌空斩
2026-01-17 02:08:36
冬窗花费破亿!曼城下一个引援目标曝光 曼联也想买他

冬窗花费破亿!曼城下一个引援目标曝光 曼联也想买他

球事百科吖
2026-01-17 05:10:30
账号被封禁,所有作品已清空!

账号被封禁,所有作品已清空!

艳姐的搞笑视频
2026-01-16 10:10:11
多个博主为西贝喊冤!不是输给罗永浩,而是输给是非不分的舆论场

多个博主为西贝喊冤!不是输给罗永浩,而是输给是非不分的舆论场

谈史论天地
2026-01-16 14:47:58
徐帆没想到,费心养了19年的女儿,也开始帮她保全婚姻的"体面"

徐帆没想到,费心养了19年的女儿,也开始帮她保全婚姻的"体面"

查尔菲的笔记
2026-01-16 19:25:52
笑死了!章泽天做节目采访刘嘉玲,没想到评论区句句都是梗!

笑死了!章泽天做节目采访刘嘉玲,没想到评论区句句都是梗!

八卦南风
2026-01-15 17:27:12
2026-01-17 07:11:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1892文章数 1443关注度
往期回顾 全部

科技要闻

贾国龙与罗永浩被禁言,微博CEO回应

头条要闻

罗永浩、贾国龙微博账号均被禁言

头条要闻

罗永浩、贾国龙微博账号均被禁言

体育要闻

全队身价=登贝莱,他们凭什么领跑法甲?

娱乐要闻

李湘翻车,早就有迹可循!

财经要闻

清流|酒店商家在携程和美团之间沦为炮灰

汽车要闻

方程豹品牌销量突破30万辆 2026年还将推出轿跑系列

态度原创

时尚
家居
亲子
本地
军事航空

今年冬天最时髦保暖的4组搭配,照着穿美出新高度!

家居要闻

岁月柔情 现代品质轻奢

亲子要闻

精神科医生:家长的“为你好”也可能对孩子造成创伤

本地新闻

云游内蒙|黄沙与碧波撞色,乌海天生会“混搭”

军事要闻

欧洲多国向格陵兰岛派遣军事人员 白宫回应

无障碍浏览 进入关怀版