网易首页 > 网易号 > 正文 申请入驻

Scikit-Learn 1.8更新:新增 Array API 支持PyTorch与CuPy加速!

0
分享至


Scikit-Learn 1.8.0 更新引入了实验性的 Array API 支持。这意味着 CuPy 数组或 PyTorch 张量现在可以直接在 Scikit-Learn 的部分组件中直接使用了,且计算过程能保留在 GPU 上。



1.8.0 到底更新了什么?

Scikit-Learn 开始正式支持Python Array API 标准。这是一个由 NumPy、CuPy、PyTorch、JAX 等库共同维护的接口规范。在 1.8.0 版本中可以实现:

  • 直接传参:受支持的评估器(estimators)现在可以直接接收 CuPy 数组或 PyTorch 张量。
  • 计算分派:运算会被自动分派到对应的非 CPU 设备(如 GPU)上执行。
  • 状态保留:模型拟合后的属性会与输入数据保持在同一物理设备上。

虽然目前的版本依然贴着“实验性”标签且需要显式开启,但它确实打破了 Scikit-Learn 过去那种“万物皆需 NumPy”的框架。

交叉验证

如果你平时不怎么用 cross_val_score、GridSearchCV 或 CalibratedClassifierCV,那你可能感觉不到这次更新的提速。但对大多数从事肃建模的开发者来说,交叉验证一直是 GPU 的“性能杀手”。

在旧版本中,即便你的基础模型(如 XGBoost)是在 GPU 上训练的,Scikit-Learn 的编排逻辑会把数组转回 NumPy,然后在 CPU 上重新计算各项指标。这种频繁的内存搬运和 CPU 的操作浪费了大量的时间,但是Array API 的加入让这种循环能基本闭环在 GPU 内部运行。

开启方式与限制

启用这项特性需要完成下面的配置。如果漏掉任何一步,程序都会悄悄退回到 NumPy 模式。

环境变量设置(必须在导入 SciPy 或 Scikit-Learn 之前):

import os
os.environ["SCIPY_ARRAY_API"] = "1"

配置 Scikit-Learn 内部开关

from sklearn import set_config
set_config(array_api_dispatch=True)

目前还有一个问题,就是不支持cuDF DataFrames。但是你依然可以用 cuDF 做数据加载和预处理,不过输入模型之前必须确保输入是 array-like 格式。也就是说类别特征必须手动编码而且且无法再依赖 pandas/cuDF 的 dtype 自动识别机制。

基于 GPU 的 XGBoost 交叉验证

下面是一个运行 5 折分层交叉验证的示例。为了让整个链路留在 GPU 上,我们需要对 XGBClassifier 做一点小的封装,并结合 cuML 的指标计算。

import os
os.environ['SCIPY_ARRAY_API'] = '1'
import cupy as cp
import cudf
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.metrics import make_scorer
from cuml.metrics import roc_auc_score
from xgboost import XGBClassifier
from sklearn import set_config
set_config(array_api_dispatch=True)
# 加载数据并进行简单的预处理
X = cudf.read_csv('/kaggle/input/playground-series-s5e12/train.csv').set_index('id')
y = X.pop('diagnosed_diabetes').astype(int)
# 类别特征编码处理
cat_cols = [c for c in X.columns if X[c].dtype == 'object']
X = X.astype({c: 'category' for c in cat_cols})
for c in cat_cols:
X[c] = X[c].cat.codes
ft = ['c' if c in cat_cols else 'q' for c in X.columns]
kfold = StratifiedKFold(5, shuffle=True, random_state=0)
# 封装 XGB 以适配 CuPy 预测
class cuXGBClassifier(XGBClassifier):
@property
def classes_(self):
return cp.asarray(super().classes_)
def predict_proba(self, X):
p = self.get_booster().inplace_predict(X)
if p.ndim == 1:
p = cp.column_stack([1 - p, p])
return p
def predict(self, X):
return cp.asarray(super().predict(X))
model = cuXGBClassifier(
enable_categorical=True,
feature_types=ft,
device='cuda',
n_jobs=4,
random_state=0
)
# 执行交叉验证
scores = cross_val_score(
model,
X.values,
y.values,
cv=kfold,
scoring=make_scorer(
roc_auc_score,
response_method="predict_proba"
),
n_jobs=1
)
print(f"{scores.mean():.5f} ± {scores.std():.5f}")

虽然这段代码看起来还是需要一些修改,但它确实能让交叉验证循环保持在 GPU 上。

现阶段支持的组件

目前 Array API 的覆盖范围还在逐步扩大。在 1.8.0 中,以下组件已经具备了较好的支持:

  • 预处理:StandardScaler、PolynomialFeatures
  • 线性模型与校准:RidgeCV、RidgeClassifierCV、CalibratedClassifierCV
  • 聚类与混合模型:GaussianMixture

官方提供的一个基于 PyTorch 的 Ridge 管道示例显示,在处理线性代数密集型任务时,这种配置在 Colab 环境下能比单核 CPU 快出 10 倍左右。

ridge_pipeline_gpu = make_pipeline(
feature_preprocessor,
FunctionTransformer(
lambda x: torch.tensor(
x.to_numpy().astype(np.float32),
device="cuda"
)
),
CalibratedClassifierCV(
RidgeClassifierCV(alphas=alphas),
method="temperature"
),
)
with sklearn.config_context(array_api_dispatch=True):
cv_results = cross_validate(
ridge_pipeline_gpu, features, target
)

总结

Scikit-Learn 准备好完全接管 GPU 了吗?显然还没有。但这个版本意义在于,它正已经向GPU的支持迈出了第一步。目前这种方式虽然还有点“硬核”,对普通用户不够友好,但对于追求极致效率的开发者来说,Scikit-Learn 1.8.0 已经要想这个方向前进了。

https://avoid.overfit.cn/post/ab7e632896364fc3b4b9fdc9d17884e3

作者:Abish Pius

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
心梗去世逐年增多?医生:牢记“3不喝、2不吃、1坚持”,别大意

心梗去世逐年增多?医生:牢记“3不喝、2不吃、1坚持”,别大意

袁医生课堂
2026-01-24 17:33:06
田家慌了!全网复刻田氏艺术,85岁雕塑家的遮羞布藏不住了

田家慌了!全网复刻田氏艺术,85岁雕塑家的遮羞布藏不住了

离离言几许
2026-01-26 12:26:16
央视发声后,新华社点评李亚鹏更“猛”,真正道出了老百姓的心声

央视发声后,新华社点评李亚鹏更“猛”,真正道出了老百姓的心声

娱乐故事
2026-01-26 17:22:02
雪豹咬人事件后续:闺蜜曝真相,美女游客或毁容,现场细节太惊心

雪豹咬人事件后续:闺蜜曝真相,美女游客或毁容,现场细节太惊心

复转这些年
2026-01-26 23:24:31
杨瀚森入选全明星新秀赛!为发展联盟代表队出战 教练是小里弗斯

杨瀚森入选全明星新秀赛!为发展联盟代表队出战 教练是小里弗斯

罗说NBA
2026-01-27 08:01:37
重磅:乌克兰突袭攻入俄罗斯领土!摧毁库尔斯克指挥所

重磅:乌克兰突袭攻入俄罗斯领土!摧毁库尔斯克指挥所

项鹏飞
2026-01-26 17:11:27
无保护徒手登顶台北101能拿多少钱?攀岩大神告诉你,报酬很尴尬

无保护徒手登顶台北101能拿多少钱?攀岩大神告诉你,报酬很尴尬

译言
2026-01-26 10:02:40
思想配得上苦难

思想配得上苦难

求实处
2026-01-26 22:20:03
李国庆被当当啪啪打脸了

李国庆被当当啪啪打脸了

不正确
2026-01-26 20:02:53
何庆魁喊话赵本山:我一个人支撑本山传媒好几年,写剧本累伤了!

何庆魁喊话赵本山:我一个人支撑本山传媒好几年,写剧本累伤了!

离离言几许
2026-01-27 00:17:24
世卫组织警告:饮用生枣椰汁前务必煮沸

世卫组织警告:饮用生枣椰汁前务必煮沸

上观新闻
2026-01-27 07:01:19
荣耀高管评iPhoneAir降价:将彻底死透!魅族前高管看不下去怒怼

荣耀高管评iPhoneAir降价:将彻底死透!魅族前高管看不下去怒怼

柴狗夫斯基
2026-01-26 08:50:27
科学家立大功!中科大解决“固态电池”量产难题,成本可降20倍

科学家立大功!中科大解决“固态电池”量产难题,成本可降20倍

胖福的小木屋
2026-01-25 21:00:49
确保思政教师收入高于其他专业,高校这一政策须兼顾公平

确保思政教师收入高于其他专业,高校这一政策须兼顾公平

读鬼笔记
2026-01-26 20:26:20
首映仅150万,《舒克贝塔》票房扑街,郑渊洁父子亏到怀疑人生

首映仅150万,《舒克贝塔》票房扑街,郑渊洁父子亏到怀疑人生

电影票房预告片
2026-01-25 00:00:56
14岁陈佳铭已昏迷超30天,上海专家会诊后,确认无生还希望

14岁陈佳铭已昏迷超30天,上海专家会诊后,确认无生还希望

离离言几许
2026-01-26 16:15:54
越南共产党新一届领导集体亮相 发展新局待启,越南将走向何方

越南共产党新一届领导集体亮相 发展新局待启,越南将走向何方

清水阿娇
2026-01-27 06:50:03
一记重拳!中国发外交照会,限日本6个月交出,118年前掠走的唐碑

一记重拳!中国发外交照会,限日本6个月交出,118年前掠走的唐碑

策略述
2026-01-26 12:32:25
一中华老字号国企董事长,打伤要债人

一中华老字号国企董事长,打伤要债人

中国新闻周刊
2026-01-26 19:31:17
退脏衣女记者全网社死!正脸很白净,坏到骨子里,山东文旅遭围攻

退脏衣女记者全网社死!正脸很白净,坏到骨子里,山东文旅遭围攻

李健政观察
2026-01-26 09:33:07
2026-01-27 08:59:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1902文章数 1445关注度
往期回顾 全部

科技要闻

理想开始关店“过冬”,否认“百家”规模

头条要闻

牛弹琴:韩国人万万没想到在睡梦中 特朗普突然下手了

头条要闻

牛弹琴:韩国人万万没想到在睡梦中 特朗普突然下手了

体育要闻

叛逆的大公子,要砸了贝克汉姆这块招牌

娱乐要闻

张雨绮被抵制成功!辽视春晚已将她除名

财经要闻

金价狂飙 “牛市神话”未完待续

汽车要闻

宾利第四台Batur敞篷版发布 解锁四项定制创新

态度原创

时尚
教育
游戏
旅游
军事航空

这些韩系穿搭最适合普通人!多穿深色、衣服基础,简洁耐看

教育要闻

针对海岛人口小县教育发展难题,浙江嵊泗县取消中考选拔功能,2025学年全县266名填报普高的初三毕业...

LPL最强战队易主!JDG双杀BLG登顶LPL,国一教实至名归?

旅游要闻

这个冬天 乌兰察布献上不一样的那达慕

军事要闻

委代总统称遭美威胁:马杜罗已死

无障碍浏览 进入关怀版