网易首页 > 网易号 > 正文 申请入驻

Google开源Tunix:JAX生态的LLM微调方案来了

0
分享至

JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。



Tunix是什么

这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:

  • 监督微调(Supervised Fine-Tuning)
  • 强化学习(Reinforcement Learning)
  • 知识蒸馏(Knowledge Distillation)

现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。

核心功能

监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。

强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。

知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。

库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。

安装

三种装法:

从PyPI装(推荐):

pip install "tunix[prod]"

或者直接从GitHub主分支:

pip install git+https://github.com/google/tunix

开发模式从源码装:

git clone https://github.com/google/tunix.git
cd tunix
pip install -e ".[dev]"

TPU上用QLoRA微调Gemma

拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。

环境准备

pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets
pip install -q git+https://github.com/google/tunix
pip install -q git+https://github.com/google/qwix
# Flax需要升级到最新版
pip uninstall -q -y flax
pip install -q git+https://github.com/google/flax.git

完整流程

第一步,从Kaggle拉预训练checkpoint:

import kagglehub
model_path = "google/gemma/flax/2b"
kaggle_ckpt_path = kagglehub.model_download(model_path)

初始化模型和tokenizer:

from flax import nnx
from tunix.models.gemma import model as gemma_lib, params as params_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
base_model = gemma_lib.Transformer.from_params(
params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),
version="2b"
)
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

挂上QLoRA adapter:

import qwix
lora_provider = qwix.LoraProvider(
module_path=".*(q_einsum|kv_einsum|proj)",
rank=16,
alpha=2.0,
weight_qtype="nf4" # enable QLoRA quantization
)
lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。

加载训练数据:

from tunix.examples.data import translation_dataset
train_ds, validation_ds = translation_dataset.create_datasets(
dataset_name="mtnt/en-fr",
global_batch_size=16,
max_target_length=256,
num_train_epochs=3,
tokenizer=tokenizer,
)

用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。

开始训练:

from tunix.sft import peft_trainer, utils
import optax
trainer = peft_trainer.PeftTrainer(
lora_model,
optimizer=optax.adamw(1e-3),
config=peft_trainer.TrainingConfig(max_steps=100)
)
trainer.train(train_ds, validation_ds)

优化器用AdamW,学习率1e-3,跑100步看看效果。

推理测试:

训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:

from tunix.generate import sampler as sampler_lib
# initialize sampler
sampler = sampler_lib.Sampler(
transformer=lora_model,
tokenizer=tokenizer,
cache_config=sampler_lib.CacheConfig(
cache_size=256,
num_layers=base_model.num_layers,
num_kv_heads=base_model.num_kv_heads,
head_dim=base_model.head_dim,
),
)
# test prompts
input_batch = [
"Translate this into French:\nHello, my name is Morgane.\n",
"Translate this into French:\nThis dish is delicious!\n",
"Translate this into French:\nI am a student.\n",
"Translate this into French:\nHow's the weather today?\n",
]
# generate predictions
out_data = sampler(
input_strings=input_batch,
max_generation_steps=20,
)
# print results
for input_string, out_string in zip(input_batch, out_data.text):
print(f"----------------------")
print(f"Prompt:\n{input_string}")
print(f"Output:\n{out_string}")

如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。

总结

100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。

Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。

后续应该会继续扩展支持的模型类型和训练算法,值得关注。

地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97

作者:Abish Pius

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
不顾央视警告顶风作案,与刘涛传出绯闻的杨烁,如今怎么样了?

不顾央视警告顶风作案,与刘涛传出绯闻的杨烁,如今怎么样了?

阿柒的讯
2025-12-08 20:27:42
就在今天!12月9日早晨,乒乓球领域传来张本智和消息

就在今天!12月9日早晨,乒乓球领域传来张本智和消息

皮皮观天下
2025-12-09 05:17:09
74岁刘銮雄坐轮椅5人伺候,甘比搀扶起身行走,每一步都小心翼翼

74岁刘銮雄坐轮椅5人伺候,甘比搀扶起身行走,每一步都小心翼翼

照见古今
2025-12-09 18:38:47
日本民宿被曝变身卖淫场,性工作者称中国游客更大方。

日本民宿被曝变身卖淫场,性工作者称中国游客更大方。

环球趣闻分享
2025-11-09 14:20:06
马上通车!南京2条重磅地铁通过验收!另外7条线路最新进展来了

马上通车!南京2条重磅地铁通过验收!另外7条线路最新进展来了

笑谈历史阿晡
2025-12-08 15:43:29
突发调整!12月10日WTT总决赛直播有变,CCTV5不直播“中日大战”

突发调整!12月10日WTT总决赛直播有变,CCTV5不直播“中日大战”

墨印斋
2025-12-09 18:41:00
他是泰山队最被高估的头牌,两年过去依旧0进球,彻底泯然众人矣

他是泰山队最被高估的头牌,两年过去依旧0进球,彻底泯然众人矣

体坛风之子
2025-12-09 04:30:04
这块奇怪的石头,在火星热带雨林中被大雨冲刷了几百万年

这块奇怪的石头,在火星热带雨林中被大雨冲刷了几百万年

星空天文
2025-12-08 20:18:29
丰田终于放下身段!RAV4荣放最高优惠4.5万,你会选择它吗?

丰田终于放下身段!RAV4荣放最高优惠4.5万,你会选择它吗?

汽车网评
2025-12-09 20:58:32
中日战机对峙后,高市早苗紧急表态,美防长敲打日本:将付出代价

中日战机对峙后,高市早苗紧急表态,美防长敲打日本:将付出代价

云鹏叙事
2025-12-08 11:04:38
“短命首相”倒计时?高市突然收到坏消息,一封检举信引爆舆论

“短命首相”倒计时?高市突然收到坏消息,一封检举信引爆舆论

男女那点事儿儿
2025-12-09 15:36:05
为什么人到中年很少有身材苗条的呢?网友:身材管理是个奢侈品

为什么人到中年很少有身材苗条的呢?网友:身材管理是个奢侈品

夜深爱杂谈
2025-12-08 20:08:11
作用全面经验丰富!马刺真的应该给这位内线老将多些出场时间?

作用全面经验丰富!马刺真的应该给这位内线老将多些出场时间?

稻谷与小麦
2025-12-09 22:40:47
利物浦换帅候选锁定两人:一人为实力之选,另一人暗藏巨大风险

利物浦换帅候选锁定两人:一人为实力之选,另一人暗藏巨大风险

夜白侃球
2025-12-09 21:54:05
三军总司令亲自动手,前总理生死不明,“巴铁”撕裂时刻到来?

三军总司令亲自动手,前总理生死不明,“巴铁”撕裂时刻到来?

东方点兵
2025-12-09 15:36:32
青海一中学校长被通报:培训期间陪妻子到八达岭长城等景点游玩,上月已受警告处分

青海一中学校长被通报:培训期间陪妻子到八达岭长城等景点游玩,上月已受警告处分

极目新闻
2025-12-09 16:53:08
中国无法原谅的“6大国家”,日本居然仅排第二,第一出乎意料?

中国无法原谅的“6大国家”,日本居然仅排第二,第一出乎意料?

爱吃醋的猫咪
2025-11-27 17:48:57
1950年11月29日,一位志愿军排长无视上级命令,放弃驻守高地

1950年11月29日,一位志愿军排长无视上级命令,放弃驻守高地

忠于法纪
2025-12-09 21:34:20
1951年,原马家军改编的解放军7师,部分官兵发动叛乱

1951年,原马家军改编的解放军7师,部分官兵发动叛乱

忠于法纪
2025-12-07 11:15:11
川普猛烈抨击欧洲;乌克兰将改-革权力机制

川普猛烈抨击欧洲;乌克兰将改-革权力机制

近距离
2025-12-09 18:16:39
2025-12-09 23:12:49
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1854文章数 1439关注度
往期回顾 全部

科技要闻

H200是不是要让中国“上瘾”?

头条要闻

31岁中国女留学生让26岁外籍男友检测性病 遭残忍杀害

头条要闻

31岁中国女留学生让26岁外籍男友检测性病 遭残忍杀害

体育要闻

“苏炳添时代”正式画上句号

娱乐要闻

尖叫之夜刘宇宁打包饼干被嘲寒酸?

财经要闻

县城经济神话,梦醒时分

汽车要闻

旗舰巨作 鸿蒙智行首款MPV智界V9信息披露

态度原创

亲子
游戏
旅游
本地
公开课

亲子要闻

惊呆!怀孕时做过的那些梦能有离谱!网友:太神奇了!

Kespa杯:T1与HLE的恩怨局,Scout和许秀的强强对话

旅游要闻

日本青森地震已致超40人受伤!有中国游客称取消赴日行程

本地新闻

云游安徽|一城活史,千年智慧守淮南

公开课

李玫瑾:为什么性格比能力更重要?

无障碍浏览 进入关怀版