网易首页 > 网易号 > 正文 申请入驻

JAX 核心特性详解:纯函数、JIT 编译、自动微分等十大必知概念

0
分享至

JAX 是 Google 和 NVIDIA 联合开发的高性能数值计算库,这两年 JAX 生态快速发展,周边工具链也日益完善了。如果你用过 NumPy 或 PyTorch,但还没接触过 JAX,这篇文章能帮助你快速上手。

围绕 JAX 已经涌现出一批好用的库:Flax用来搭神经网络,Optax处理梯度和优化,Equinox提供类似 PyTorch 的接口,Haiku则是简洁的函数式 API,Jraph用于图神经网络,RLax是强化学习工具库,Chex提供测试和调试工具,Orbax负责模型检查点和持久化。



纯函数是硬需求

JAX 对函数有个基本要求:必须是纯函数。这意味着函数不能有副作用,对同样的输入必须总是返回同样的输出。

这个约束来自函数式编程范式。JAX 内部做各种变换(编译、自动微分等)依赖纯函数的特性,用不纯的函数可能导致错误或静默失败,结果完全不对。

# 纯函数,没问题
def pure_addition(a, b):
return a + b
# 不纯的函数,JAX 不接受
counter = 0
def impure_addition(a, b):
global counter
counter += 1
return a + b

JAX NumPy 与原生 NumPy

JAX 提供了类 NumPy 的接口,核心优势在于能自动高效地在 CPU、GPU 甚至 TPU 上运行,支持本地或分布式执行。这套能力来自XLA(Accelerated Linear Algebra)编译器,它把 JAX 代码翻译成针对不同硬件优化的机器码。

NumPy 默认只在 CPU 上跑,JAX NumPy 则不同。用法上两者很相似,这也是 JAX 容易上手的原因。


# JAX 也差不多
import jax.numpy as jnp
print(jnp.sqrt(4))# NumPy 的写法
import numpy as np
print(np.sqrt(4))
# JAX 也差不多
import jax.numpy as jnp
print(jnp.sqrt(4))

常见的操作两者看起来基本一样:

import numpy as np
import jax.numpy as jnp
# 创建数组
np_a = np.array([1.0, 2.0, 3.0])
jnp_a = jnp.array([1.0, 2.0, 3.0])
# 元素级操作
print(np_a + 2)
print(jnp_a + 2)
# 广播
np_b = np.array([[1, 2, 3]])
jnp_b = jnp.array([[1, 2, 3]])
print(np_b + np.arange(3))
print(jnp_b + jnp.arange(3))
# 求和
print(np.sum(np_a))
print(jnp.sum(jnp_a))
# 平均值
print(np.mean(np_a))
print(jnp.mean(jnp_a))
# 点积
print(np.dot(np_a, np_a))
print(jnp.dot(jnp_a, jnp_a))

但有个重要差异需要注意:

JAX 数组是不可变的,对数组的修改操作会返回新数组而不是改变原数组。

NumPy 数组则可以直接修改:

import numpy as np
x = np.array([1, 2, 3])
x[0] = 10 # 直接修改,没问题

JAX 这边就不行了:

import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x[0] = 10 # 报错

但是JAX 提供了专门的 API 来处理这种情况,通过返回一个新数组的方式实现"修改":

z = x.at[idx].set(y)

完整的例子:

x = jnp.array([1, 2, 3])
y = x.at[0].set(10)
print(y) # [10, 2, 3]
print(x) # [1, 2, 3](没变)

JIT 编译加速

即时编译(JIT)是 JAX 里一个核心特性,通过 XLA 把 Python/JAX 代码编译成优化后的机器码。

直接用 Python 解释器跑函数会很慢。加上 @jit 装饰器后,函数会被编译成快速的原生代码:

from jax import jit
# 不编译的版本
def square(x):
return x * x
# 编译过的版本
@jit
def jit_square(x):
return x * x

jit_square 快好几个数量级。函数首次调用时,JIT 引擎会:

  1. 追踪函数逻辑,构建计算图
  2. 把图编译成优化的 XLA 代码
  3. 缓存编译结果
  4. 后续调用直接用缓存的版本

自动微分

JAX 的grad模块能自动计算函数的导数。

import jax.numpy as jnp
from jax import grad
# 定义函数:f(x) = x² + 2x + 2
def f(x):
return x**2 + 2 * x + 2
# 计算导数
df_dx = grad(f)
# 在 x = 2.0 处求值
print(df_dx(2.0)) # 6.0

随机数处理

NumPy 用全局随机状态生成随机数。每次调用 np.random.random() 时,NumPy 会更新隐藏的内部状态:

import numpy as np
np.random.random()
# 0.9539264374520571

JAX 的做法完全不同。作为纯函数库,它不能维护全局状态,所以要求显式传入一个伪随机数生成器(PRNG)密钥。每次生成随机数前要先分割密钥:

from jax import random
# 初始化密钥
key = random.PRNGKey(0)
# 每次生成前分割
key, subkey = random.split(key)
# 从正态分布采样
x = random.normal(subkey, ())
print(x) # -2.4424558
# 从均匀分布采样
key, subkey = random.split(key)
u = random.uniform(subkey, (), minval=0.0, maxval=1.0)
print(u) # 0.104290366

一个常见的坑:同一个密钥生成的随机数始终相同。

# 用同一个 subkey,结果重复
x = random.normal(subkey, ())
print(x) # -2.4424558
x = random.normal(subkey, ())
print(x) # -2.4424558(还是这个值)

所以要记住总是用新密钥。

向量化:vmap

vmap自动把函数转换成能处理批量数据的版本。逻辑上就像循环遍历每个样本,但执行效率远高于 Python 循环。

import jax.numpy as jnp
from jax import vmap
def f(x):
return x * x + 1
arr = jnp.array([1., 2., 3., 4.])
# Python 循环(慢)
outputs_loop = jnp.array([f(x) for x in arr])
# vmap 版本(快)
f_vectorized = vmap(f)
outputs_vmap = f_vectorized(arr)

并行化:pmap

pmap不同于 vmap。vmap 在单个设备上做批处理,pmap 把计算分散到多个设备(GPU/TPU 核心),每个设备处理输入的一部分。

VMAP:单设备批处理向量化

PMAP:跨多设备并行执行

import jax.numpy as jnp
from jax import pmap
# 查看可用设备
print(jax.devices()) # [TpuDevice(id=0), TpuDevice(id=1), ..., TpuDevice(id=7)]
def f(x):
return x * x + 1
arr = jnp.array([1., 2., 3., 4.])
# pmap 会把数组分配到不同设备
ys = pmap(f)(arr)

PyTrees

PyTree 在 JAX 里是个常见的概念:任何嵌套的 Python 容器(列表、字典、元组等)加上基本类型的组合。JAX 里用它来组织模型参数、优化器状态、梯度等。

import jax.numpy as jnp
from jax import tree_util as tu
# 构建 PyTree
pytree = {
"a": jnp.array([1, 2]),
"b": [jnp.array([3, 4]), 5]
}
# 获取所有叶子节点
leaves = tu.tree_leaves(pytree)
# 对每个叶子应用函数
doubled = tu.tree_map(lambda x: x * 2, pytree)

Optax:梯度处理和优化

Optax 是 JAX 生态里的优化库。它包含损失函数、优化器、梯度变换、学习率调度等一套工具。

损失函数:

logits = jnp.array([[2.0, -1.0]])
labels_onehot = jnp.array([[1.0, 0.0]])
labels_int = jnp.array([0])
# Softmax 交叉熵(独热编码)
loss_softmax_onehot = optax.softmax_cross_entropy(logits, labels_onehot).mean()
# Softmax 交叉熵(整数标签)
loss_softmax_int = optax.softmax_cross_entropy_with_integer_labels(logits, labels_int).mean()
# 二元交叉熵
loss_bce = optax.sigmoid_binary_cross_entropy(logits, labels_onehot).mean()
# L2 损失
loss_l2 = optax.l2_loss(jnp.array([1., 2.]), jnp.array([0., 1.])).mean()
# Huber 损失
loss_huber = optax.huber_loss(jnp.array([1.,2.]), jnp.array([0.,1.])).mean()

优化器:

# SGD
opt_sgd = optax.sgd(learning_rate=1e-2)
# SGD with momentum
opt_momentum = optax.sgd(learning_rate=1e-2, momentum=0.9)
# RMSProp
opt_rmsprop = optax.rmsprop(1e-3)
# Adafactor
opt_adafactor = optax.adafactor(learning_rate=1e-3)
# Adam
opt_adam = optax.adam(1e-3)
# AdamW
opt_adamw = optax.adamw(1e-3, weight_decay=1e-4)

梯度变换:

# 梯度裁剪
tx_clip = optax.clip(1.0)
# 全局梯度范数裁剪
tx_clip_global = optax.clip_by_global_norm(1.0)
# 权重衰减(L2)
tx_weight_decay = optax.add_decayed_weights(1e-4)
# 添加梯度噪声
tx_noise = optax.add_noise(0.01)

学习率调度:

# 指数衰减
lr_exp = optax.exponential_decay(init_value=1e-3, transition_steps=1000, decay_rate=0.99)
# 余弦衰减
lr_cos = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
# 线性预热
lr_linear = optax.linear_schedule(init_value=0.0, end_value=1e-3, transition_steps=500)

更新步骤:

# 计算梯度
loss, grads = jax.value_and_grad(loss_fn)(params)
# 生成优化器更新
updates, opt_state = optimizer.update(grads, opt_state)
# 应用更新
params = optax.apply_updates(params, updates)

链式组合:

# 把多个操作链起来
optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # 梯度裁剪
optax.add_decayed_weights(1e-4), # 权重衰减
optax.adam(1e-3) # Adam 优化
)

Flax 与神经网络

JAX 本身只是数值计算库,Flax 在其基础上提供了神经网络定义和训练的高级 API。Flax 代码风格接近 PyTorch,如果你用过 PyTorch 会很快上手。

Flax 提供了丰富的层和操作。基础层包括全连接层 Dense、卷积 Conv、嵌入 Embed、多头注意力 MultiHeadDotProductAttention 等:

flax.linen.Dense(features=128)
flax.linen.Conv(features=64, kernel_size=(3, 3))
flax.linen.Embed(num_embeddings=10000, features=256)
flax.linen.MultiHeadDotProductAttention(num_heads=8)
flax.linen.SelfAttention(num_heads=8)

归一化支持多种方式:

flax.linen.BatchNorm()
flax.linen.LayerNorm()
flax.linen.GroupNorm(num_groups=32)
flax.linen.RMSNorm()

激活和 Dropout:

flax.linen.relu(x)
flax.linen.gelu(x)
flax.linen.sigmoid(x)
flax.linen.tanh(x)
flax.linen.Dropout(rate=0.1)

池化:

flax.linen.avg_pool(x, window_shape=(2,2), strides=(2,2))
flax.linen.max_pool(x, window_shape=(2,2), strides=(2,2))

循环层:

flax.linen.LSTMCell()
flax.linen.GRUCell()
flax.linen.OptimizedLSTMCell()

下面是一个简单的多层感知机(MLP)例子:

import jax
import jax.numpy as jnp
from flax import linen as nn
class MLP(nn.Module):
features: list
@nn.compact
def __call__(self, x):
for f in self.features[:-1]:
x = nn.Dense(f)(x)
x = nn.relu(x)
x = nn.Dense(self.features[-1])(x)
return x
model = MLP([32, 16, 10])
key = jax.random.PRNGKey(0)
# 输入:batch_size=1, 特征数=4
x = jnp.ones((1, 4))
# 初始化参数
params = model.init(key, x)
# 前向传播
y = model.apply(params, x)
print("Input:", x)
# Input: [[1. 1. 1. 1.]]
print("Input shape:", x.shape)
# Input shape: (1, 4)
print("Output:", y)
# Output: [[ 0.51415515 0.36979797 0.6212194 -0.74496573 -0.8318489 0.6590691 0.89224255 0.00737424 0.33062232 0.34577468]]
print("Output shape:", y.shape)
# Output shape: (1, 10)

Flax 用 @nn.compact 装饰器,让你在 __call__ 方法里直接定义层。参数是独立于模型对象存储的,需要通过 init 方法显式初始化,然后在 apply 方法中使用。

总结

JAX 的出现解决了一个长期存在的问题:如何让 Python 科学计算既保持灵活性,又能获得接近 C/CUDA 的性能。

不过 JAX 的学习曲线确实比 PyTorch 陡。纯函数的约束、不可变数组的特性、显式密钥管理等细节起初会有些别扭。但一旦习惯会发现它带来的优雅和灵活性。

https://avoid.overfit.cn/post/a16194fdc3ea450f858515d7cb3d49c4

作者:Ashish Bamania

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
港府官员访美“急刹车”,中央政府下达指令,绝不给美国可乘之机

港府官员访美“急刹车”,中央政府下达指令,绝不给美国可乘之机

依偎在角落
2026-01-09 10:21:16
超出很多人预料!女排联赛第八轮爆出两大冷门

超出很多人预料!女排联赛第八轮爆出两大冷门

老高说体育
2026-01-11 12:12:55
深挖范继文:培养奥运冠军被记功!一手遮天猥亵女队员 妻子曝光

深挖范继文:培养奥运冠军被记功!一手遮天猥亵女队员 妻子曝光

念洲
2026-01-11 12:02:37
羊肉再次被关注!医生发现:高血糖人吃羊肉,不过多久或有几益处

羊肉再次被关注!医生发现:高血糖人吃羊肉,不过多久或有几益处

阿兵科普
2026-01-10 22:55:47
19岁河南小伙在伦敦“捡”了个瑞士爷爷,回国后硬是把他接来郑州养老,临终前老人说了一句话,让无数人破防

19岁河南小伙在伦敦“捡”了个瑞士爷爷,回国后硬是把他接来郑州养老,临终前老人说了一句话,让无数人破防

源溯历史
2026-01-09 18:47:21
王腾回应新公司为何不招应届生:草台班子刚起步,待业务发展起来后欢迎加入

王腾回应新公司为何不招应届生:草台班子刚起步,待业务发展起来后欢迎加入

新浪财经
2026-01-09 12:52:51
光明正大摸鱼!25岁女子在初创公司没事做,老板:你无聊就看电视

光明正大摸鱼!25岁女子在初创公司没事做,老板:你无聊就看电视

唐小糖说情感
2026-01-10 21:27:48
你听过最劲爆的瓜是啥?网友:被大八岁的补习班老师表白了

你听过最劲爆的瓜是啥?网友:被大八岁的补习班老师表白了

带你感受人间冷暖
2025-11-26 00:10:06
牡丹花下死!46岁"纵欲过度"的萧亚轩,终是为自己行为买了单

牡丹花下死!46岁"纵欲过度"的萧亚轩,终是为自己行为买了单

有趣的胡侃
2026-01-11 11:37:15
36岁TVB视后终于有剧拍!直呼单身很久,不用金牌媒人做媒!

36岁TVB视后终于有剧拍!直呼单身很久,不用金牌媒人做媒!

我爱追港剧
2026-01-10 23:41:08
西方越想越害怕!俄军榛树高超不用弹头,170亿立方米天然气泄露

西方越想越害怕!俄军榛树高超不用弹头,170亿立方米天然气泄露

松林看世界
2026-01-11 07:07:34
广东新增一所本科高校

广东新增一所本科高校

鲁中晨报
2026-01-11 11:29:02
A股:从下周起,或许历史将惊人相似!4500点大级别主升浪要来了

A股:从下周起,或许历史将惊人相似!4500点大级别主升浪要来了

夜深爱杂谈
2026-01-10 21:41:39
一顿乱扔!11投0中,全场0分0板1助,媒体人:需苦练,球迷:一根筋

一顿乱扔!11投0中,全场0分0板1助,媒体人:需苦练,球迷:一根筋

金山话体育
2026-01-11 07:12:00
油价大跌超1.7元/升,大降到6元时代的汽柴油,1月20日或再次下跌

油价大跌超1.7元/升,大降到6元时代的汽柴油,1月20日或再次下跌

油价早知道
2026-01-11 01:31:10
上海一女子肩膀疼以为是肩周炎,1周后离世,医生怒斥:太无知

上海一女子肩膀疼以为是肩周炎,1周后离世,医生怒斥:太无知

刘哥谈体育
2026-01-10 13:40:30
闫学晶酸黄瓜事件升级!其过往婚史被扒,人脉金钱两手抓,不简单

闫学晶酸黄瓜事件升级!其过往婚史被扒,人脉金钱两手抓,不简单

深析古今
2026-01-09 11:04:08
他8次上春晚,作死被捕入狱,如今56岁无人问津,沦落到四处走穴

他8次上春晚,作死被捕入狱,如今56岁无人问津,沦落到四处走穴

小熊侃史
2026-01-06 11:17:00
U23国足VS澳大利亚:彭啸坐镇 海港天才新星临危受命 王钰栋冲锋

U23国足VS澳大利亚:彭啸坐镇 海港天才新星临危受命 王钰栋冲锋

零度眼看球
2026-01-11 07:20:06
全球用户大面积中招:鼠标突然就“坏了”!不少人按到“手抽筋”,重装卸载也不管用,罗技回应

全球用户大面积中招:鼠标突然就“坏了”!不少人按到“手抽筋”,重装卸载也不管用,罗技回应

每日经济新闻
2026-01-08 20:15:12
2026-01-11 13:00:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1886文章数 1441关注度
往期回顾 全部

科技要闻

“我们与美国的差距也许还在拉大”

头条要闻

马杜罗之子:马杜罗在美监狱说"我们很好我们是斗士"

头条要闻

马杜罗之子:马杜罗在美监狱说"我们很好我们是斗士"

体育要闻

詹皇晒照不满打手没哨 裁判报告最后两分钟无误判

娱乐要闻

网友偶遇贾玲张小斐崇礼滑雪

财经要闻

外卖平台"烧钱抢存量市场"迎来终局?

汽车要闻

2026款宋Pro DM-i长续航补贴后9.98万起

态度原创

时尚
教育
家居
健康
亲子

动物纹回潮,那很狂野了

教育要闻

别抱怨你的孩子找不到工作,原因可能是这几个,要高度重视

家居要闻

木色留白 演绎现代自由

这些新疗法,让化疗不再那么痛苦

亲子要闻

深度长文:原始社会婴儿哭声会引来大量天敌,婴儿如何生存下看?

无障碍浏览 进入关怀版