网易首页 > 网易号 > 正文 申请入驻

JAX性能优化实战:7个变换让TPU/GPU吃满算力

0
分享至

JAX跑得快的技巧其实很简单:通过组合变换让XLA能看到大块连续的计算,比如说批处理、融合、分片,让每一步在单设备或多设备同步时都像一个干净的kernel。

我们今天就来总结7个能够提高运行速度的JAX变换组合



1、 jit 优先,形状稳定

jit对函数做一次追踪后XLA负责融合算子,形状稳定、无副作用时,Python处理的开销就被分摊掉,可以提高运行速度。

形状创建和静态参数要么挪到step外部,要么显式标记为static。donate_argnums能让JAX复用缓冲区,省掉不必要的内存拷贝。step之间保持dtype和shape一致,trace结果才能被缓存下来。

import jax, jax.numpy as jnp
@jax.jit(donate_argnums=(0,))
def sgd_step(params, batch, lr):
x, y = batch
def loss_fn(p):
preds = model_apply(p, x) # pure function
return jnp.mean((preds - y) ** 2)
grads = jax.grad(loss_fn)(params)
return jax.tree_map(lambda p, g: p - lr * g, params, grads)

每个(shape, dtype, static-arg)组合只追踪一次。频繁retrace多半是输入shape在变,或者Python逻辑泄漏进了计算图。

2、vmap替换Python循环

vmap在leading axis上做向量化,XLA直接把batch融进kernel。for循环没了设备launch就少了,内存访问也更连续。

# per-example loss
def example_loss(params, x, y):
pred = model_apply(params, x)
return jnp.mean((pred - y) ** 2)
# batch it without writing loops
batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0)) # params broadcasted

嵌套vmap可以搞2D batch,比如time × batch,只要别超HBM容量。vmap适合做内层微批处理,比如ensemble或MC sampling这类场景,外层维度留给分片。

3、长循环的融合利器Scan

RNN、展开解码、迭代求解器,这些场景用scan比Python循环快。scan只编译一次循环体跑在XLA的while-loop里,Python开销基本为0,融合和内存复用也更激进。

from jax import lax
def rnn_cell(carry, x):
h = carry
h = jnp.tanh(W_hh @ h + W_xh @ x + b)
y = W_hy @ h
return h, y # (carry, output)
def rnn_forward(h0, xs):
hT, ys = lax.scan(rnn_cell, h0, xs) # xs: [T, B, D]
return hT, ys

循环状态用carry传递,body保持小而纯净,要注意保持形状不要变,比如:序列模型、diffusion step循环、定点迭代、beam解码(形状稳定时)都适用。

4、remat可以用计算换内存

批次大了TPU/GPU的FLOP利用率往往更高。remat(也叫checkpoint)会丢掉部分中间激活,反向时重算这样峰值显存下来batch就能开的更大。

from jax import remat
def block(params, x):
x = jax.nn.gelu(x @ params['w1'])
x = x @ params['w2']
return x
fast_block = remat(block) # checkpointed
@jax.jit
def forward(params, x):
for _ in range(6):
x = x + fast_block(params, x)
return x

只包最重的子块就行,比如attention加MLP那几层。同时配合vmap或分片,全局batch能再往上拉。不过需要一些额外FLOPs,但如果换来1.3到2倍的batch increase,wall-clock往往更短。

5、pmap单机多卡数据并行

pmap把函数复制到单主机的多个设备上(8卡工作站、单节点8核TPU),梯度可以自动all-reduce,并且每设备只编译一次。

from jax import pmap, lax
@pmap(axis_name='d')
def train_step(params, batch, lr):
x, y = batch # each device sees [local_B, ...]
def loss_fn(p):
pred = model_apply(p, x)
loss = jnp.mean((pred - y) ** 2)
return loss
loss, grads = jax.value_and_grad(loss_fn)(params)
loss = lax.pmean(loss, axis_name='d')
grads = lax.pmean(grads, axis_name='d')
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
return params, loss

batch在leading axis分片,lax.pmean聚合loss和grads。单机场景下pmap简单可靠。跨主机扩展或者想做张量级细粒度分片可以成换pjit。

6、pjit+ 命名分片:SPMD并行

pjit编译出单一SPMD程序可以跨设备跨主机运行。用mesh和PartitionSpec描述数组怎么切,JAX处理collective通信,这样数据并行、张量并行、混合并行都能做。

import jax
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np
devices = np.array(jax.devices()).reshape(2, 4) # 2 × 4 mesh (dp × mp)
mesh = Mesh(devices, ('dp', 'mp'))
@jax.jit # jit is optional when using pjit; shown when composing
def model_apply_sharded(params, x):
return model_apply(params, x)
from jax.experimental.pjit import pjit
with mesh:
in_shard = (P('mp',), P('dp',)) # example; tailor to your shapes
out_shard = P('dp',) # e.g., shard batch across dp
step = pjit(model_apply_sharded,
in_shardings=(P('mp',), P('dp',)),
out_shardings=out_shard)
y = step(params_sharded, x_sharded)

一般都是batch轴走dp,大矩阵维度(hidden size、heads)走mp。分片数需要跟设备拓扑对齐,跨主机流量才少。

7、value_and_grad的正确堆叠方式

规范写法是jit(value_and_grad(loss, has_aux=True)),外面可以再套一层pmap或pjit。这样forward只跑一遍metrics留在aux里带出来。

def loss_with_aux(params, batch):
x, y = batch
pred = model_apply(params, x)
loss = jnp.mean((pred - y) ** 2)
aux = {'mse': loss, 'mean_pred': jnp.mean(pred)}
return loss, aux
@jax.jit
def train_step(params, opt_state, batch, lr):
(loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch)
updates, opt_state = optimizer_update(grads, opt_state, params, lr)
params = optax_apply(updates, params)
return params, opt_state, loss, aux

value_and_grad放jit里面,JAX会把forward和backward一起stage。返回(loss, aux)日志指标不用再跑一遍forward。

这套组合很灵活:vmap做微批次,scan跑时序循环,外面套pmap或pjit,donate_argnums标上buffer。

总结

变长序列pad加mask,shape稳定是前提条件。traced代码里不要添加Python随机性,比如PRNG key要在外面split好。矩阵乘用bfloat16,这样数值稳定性也够用,吞吐量在TPU/GPU上表现的也很好。性能profile要重点看warm-up之后的tokens/sec或samples/sec。日志只看标量aux metrics就行,每step把大数组传回host是性能杀手。

JAX的性能不是黑盒:jit + shape可以稳定打底,vmap做batch,scan融合循环,remat回收显存,pmap或pjit做扩展,value_and_grad(..., has_aux=True)让每一步只跑一次forward一次backward。

https://avoid.overfit.cn/post/84e4e28e3ca8473488a0e9248d1ec51b

作者:Nexumo

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
真相大白!11级新疆班学生发声,辟谣林傲霏是同学,班级名单曝光

真相大白!11级新疆班学生发声,辟谣林傲霏是同学,班级名单曝光

古希腊掌管月桂的神
2026-01-14 16:28:57
钱都让陈小群赚走了

钱都让陈小群赚走了

独孤老赵的笔记
2026-01-14 16:10:34
DeepSeek梁文锋喊话罗永浩:靠嘴年入过亿,为啥非得做科技?

DeepSeek梁文锋喊话罗永浩:靠嘴年入过亿,为啥非得做科技?

雷科技
2026-01-14 15:31:09
员工餐补贴大概率是要到头了!

员工餐补贴大概率是要到头了!

达文西看世界
2026-01-14 14:24:40
广东宏远拒绝输球!全力击败上海男篮,奎因对位张镇麟,央视直播

广东宏远拒绝输球!全力击败上海男篮,奎因对位张镇麟,央视直播

体坛瞎白话
2026-01-14 15:04:40
离婚才3个月,高调谈论再婚的李亚鹏,没给海哈金喜留一丝体面

离婚才3个月,高调谈论再婚的李亚鹏,没给海哈金喜留一丝体面

人间无味啊
2026-01-14 00:15:59
知名演员保剑锋方发声:法庭见!

知名演员保剑锋方发声:法庭见!

南方都市报
2026-01-14 19:37:55
贪污上亿、假慈善?被实名举报的韩红,明白她终身不嫁的原因

贪污上亿、假慈善?被实名举报的韩红,明白她终身不嫁的原因

泠泠说史
2025-11-27 18:18:02
A股:周三夜晚传来5大核弹级消息!明天很可能会迎更大级别大行情?

A股:周三夜晚传来5大核弹级消息!明天很可能会迎更大级别大行情?

股市皆大事
2026-01-14 18:26:59
动手前通知中美,不想被中美同时误判,普京遭斩首未遂,报复太狠

动手前通知中美,不想被中美同时误判,普京遭斩首未遂,报复太狠

科普100克克
2026-01-14 14:28:07
伊朗抗议血腥镇压:数百青年头颈中枪殒命,23岁女大学生遗体被扣

伊朗抗议血腥镇压:数百青年头颈中枪殒命,23岁女大学生遗体被扣

译言
2026-01-13 11:00:50
美国人意识到,贸易战之后,不会再有中国外的大规模工业化国家了

美国人意识到,贸易战之后,不会再有中国外的大规模工业化国家了

沧海旅行家
2026-01-14 14:44:50
王石田朴珺运动照流出!网友:这老头,太不容易了……

王石田朴珺运动照流出!网友:这老头,太不容易了……

麦杰逊
2026-01-13 12:09:35
种菜种到了政府储备用地上?深圳沙河街道回应

种菜种到了政府储备用地上?深圳沙河街道回应

深圳晚报
2026-01-14 19:31:45
刘嘉玲豪宅曝光,墙壁成亮点,她吃啥狗吃啥,多次直言佩服刘德华

刘嘉玲豪宅曝光,墙壁成亮点,她吃啥狗吃啥,多次直言佩服刘德华

查尔菲的笔记
2026-01-14 16:31:11
“耄耋耆耈”这四个字你认识吗?是什么意思呢?读错小心闹笑话

“耄耋耆耈”这四个字你认识吗?是什么意思呢?读错小心闹笑话

长风文史
2026-01-14 11:40:58
瓦良格号送到中国后有多震撼?专家刮掉表面的锈迹:钢材品质极佳

瓦良格号送到中国后有多震撼?专家刮掉表面的锈迹:钢材品质极佳

古书记史
2026-01-06 16:31:56
重庆合川“摇人按猪”女孩粉丝破190万!60秒以上视频广告报价2400元,商标被多方申请注册,网友提议→

重庆合川“摇人按猪”女孩粉丝破190万!60秒以上视频广告报价2400元,商标被多方申请注册,网友提议→

封面新闻
2026-01-13 16:17:06
又揪出来一个巨贪,金额高达9.7亿,首富夫人郝斌跨境逃亡失败了

又揪出来一个巨贪,金额高达9.7亿,首富夫人郝斌跨境逃亡失败了

南权先生
2026-01-14 16:49:58
“网红干部”贺娇龙因意外事故去世,年仅47岁

“网红干部”贺娇龙因意外事故去世,年仅47岁

韩小娱
2026-01-14 10:20:40
2026-01-14 21:28:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1890文章数 1443关注度
往期回顾 全部

科技要闻

携程因涉嫌垄断被市场监管总局调查

头条要闻

国企领导超83%赃款来自境外:钱藏在10个国家和地区

头条要闻

国企领导超83%赃款来自境外:钱藏在10个国家和地区

体育要闻

你是个好球员,我们就拿你交易吧

娱乐要闻

何晴去世30天,许亚军终于发声

财经要闻

涉嫌垄断!市场监管总局对携程立案调查

汽车要闻

曝Model Y或降到20万以内!

态度原创

手机
艺术
本地
健康
教育

手机要闻

iPhone Fold模具首曝:闭合如mini、展开似iPad,设计神似OPPO!

艺术要闻

八大山人『山水花鸟册』

本地新闻

邵阳公益诉讼检察主题曲:《守望星》

血常规3项异常,是身体警报!

教育要闻

多图直击:北京各小学让孩子在“乐”中“考”出未来素养

无障碍浏览 进入关怀版