网易首页 > 网易号 > 正文 申请入驻

TensorFlow 决策森林详细介绍和使用说明

0
分享至

使用TensorFlow训练、调优、评估、解释和部署基于树的模型的完整教程

两年前TensorFlow (TF)团队开源了一个库来训练基于树的模型,称为TensorFlow决策森林(TFDF)。经过了2年的测试,他们在上个月终于宣布这个包已经准备好发布了,也就是说我们可以真正的开始使用了。所以这篇文章将详细介绍这个软件包,并向你展示如何(有效地)使用它。

在这篇文章中,我们将使用美国小企业管理局数据集训练一些贷款违约预测模型。模型将使用已经预处理的数据进行训练。安装TensorFlow决策森林非常简单,只需运行pip install tensorflow_decision_forests。

TensorFlow Decision Forest

1、什么是TFDF?

TensorFlow决策森林实际上是建立在c++的Yggdrasil决策森林之上库的,Yggdrasil决策森林也是由谷歌开发的。最初的c++算法旨在构建可扩展的决策树模型,可以处理大型数据集和高维特征空间。通过将这个库集成到更广泛的TF生态系统中,用户无需学习另一种语言就可以轻松地构建可扩展的RF和GBT模型。

2、为什么要用它?

与XGBoost或LightGBM相比,这个库的主要优势在于它与其他TF生态系统组件紧密集成。对于已经将其他TensorFlow模型作为管道的一部分或使用TFX的团队来说,这是非常有用的,因为TFDF可以很容易地与NLP模型集成。如果你正在使用TF Serving为模型提供对外服务,这个库也是可以用的,因为它是官方的原生支持(不需要ONNX或其他跨包序列化方法)模型的部署。最后这个库为还提供了大量参数,可以根据XGBoost、LightGBM和许多其他梯度增强机(GBM)方法来调整获得近似模型。这意味着不需要在训练过程中在不同的GBM库之间切换,这从代码可维护性的角度来说非常好。

模型训练

1、数据准备

我们使用了数据处理后的版本,所以不需要进行数据的预处理:

# Read in data
train_data: pd.DataFrame = pd.read_parquet("../data/train_data.parquet")
val_data: pd.DataFrame = pd.read_parquet("../data/val_data.parquet")
test_data: pd.DataFrame = pd.read_parquet("../data/test_data.parquet")
# Set data types
NUMERIC_FEATURES = [
"Term",
"NoEmp",
"CreateJob",
"RetainedJob",
"longitude",
"latitude",
"GrAppv",
"SBA_Appv",
"is_new",
"same_state",
]
CATEGORICAL_FEATURES = [
"FranchiseCode",
"UrbanRural",
"City",
"State",
"Bank",
"BankState",
"RevLineCr",
"naics_first_two",
]
TARGET = "is_default"
# Make sure that datatypes are consistent
dsets = [train_data,val_data,test_data]
for d in dsets:
d[NUMERIC_FEATURES] = d[NUMERIC_FEATURES].astype(np.float32)
d[CATEGORICAL_FEATURES] = d[CATEGORICAL_FEATURES].astype(str)

2、特征

为了确保项目结构良好并避免意外行为,可以为每个特性指定一个FeatureUsage,尽管这不是强制性的,但是使用这种方法可以让我们的项目更加规范。并且这也是一项简单的任务:只需要从支持的六种类型(BOOLEAN、CATEGORICAL、CATEGORICAL_SET、DISCRETIZED_NUMERICAL、HASH和NUMERICAL)中决定将哪些特征类型分配给哪个类型就可以了。其中一些类型带有额外的参数,所以请确保在这里阅读更多关于它们的信息。

在本例中我们将保持简单,只使用数值和类别数据类型,但是需要说明下DISCRETIZED_NUMERICAL,它可以显著加快训练过程(类似于LightGBM)。我们使用的代码如下,指定所选的数据类型,对于类别特征,还要指定min_vocab_frequency参数以去除罕见值。

import tensorflow_decision_forests as tfdf
# Prepare Feature Usage list
feature_usages = []
# Numerical features
for feature_name in NUMERIC_FEATURES:
feature_usage = tfdf.keras.FeatureUsage(
name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
)
feature_usages.append(feature_usage)
# Categorical features
for feature_name in CATEGORICAL_FEATURES:
feature_usage = tfdf.keras.FeatureUsage(
name=feature_name,
semantic=tfdf.keras.FeatureSemantic.CATEGORICAL,
min_vocab_frequency=1000,
)
feature_usages.append(feature_usage)

3、使用TF Dataset读取数据

读取数据的最简单方法是使用TF Dataset。TFDF有一个非常好的实用函数pd_dataframe_to_tf_dataset,它使这一步变得非常简单。

# Use TF Dataset to read in data
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
train_data, label=TARGET, weight=None, batch_size=1000
)
val_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
val_data, label=TARGET, weight=None, batch_size=1000
)
test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
test_data, label=TARGET, weight=None, batch_size=1000
)

在上面的代码中,我们将DataFrame对象传递到函数中,并提供以下参数:

  • label 列的名称
  • weight 列的名称(在本例中为None)
  • 批大小(有助于加快数据的读取)

得到的数据集是就是TF Dataset的格式。也可以创建自己的方法来读取数据集,但必须特别注意输出的格式,没有这个方法这样方便。

3、TFDF默认参数

## Define the models
# Gradient Boosted Trees model
gbt_model = tfdf.keras.GradientBoostedTreesModel(
features=feature_usages,
exclude_non_specified_features=True,
)
# Random Forest modle
rf_model = tfdf.keras.RandomForestModel(
features=feature_usages,
exclude_non_specified_features=True,
)
# Compile the models (Optional)
gbt_model.compile(metrics=[tf.keras.metrics.AUC(curve="PR")])
rf_model.compile(metrics=[tf.keras.metrics.AUC(curve="PR")])
# Fit the models
gbt_model.fit(train_dataset, validation_data=val_dataset)
rf_model.fit(train_dataset, validation_data=val_dataset)

只需要几行代码就可以使用默认参数构建和训练GBT和RF模型。当使用ROC和PR auc评估这两个模型时可以看到性能已经相当好了。

# GBT with Default Parameters
PR AUC: 0.8367
ROC AUC: 0.9583
# RF with Default Parameters
PR AUC: 0.8102
ROC AUC: 0.9453

那么是否可以进行超参数调优进一步改善这些结果呢。

超参数调优

Yggdrasil官方文档中有大量的参数可以进行调优,每一个参数都有很好的解释。TFDF也提供了一些内置选项来调优参数,为了简单也使用超参数的搜索库,例如Optuna或Hyperpot。

1、超参数模板

TFDF提供的非常好特性就是超参数模板。这些参数在论文中被证明在广泛的数据集上表现最好。有两个可用的模板:better_default和benchmark_rank。如果你时间不够,或者对机器学习不太熟悉,这是一个不错的选择。指定这些参数只需要一行代码。

# Define the models
better_default_gbt_model = tfdf.keras.GradientBoostedTreesModel(
hyperparameter_template='better_default', # template 1
features=feature_usages,
exclude_non_specified_features=True,
)
benchmark_gbt_model = tfdf.keras.GradientBoostedTreesModel(
hyperparameter_template='benchmark_rank1', # template 2
features=feature_usages,
exclude_non_specified_features=True,
)
# Fit the models (notice that we're skipping compiling step)
better_default_gbt_model.fit(train_dataset, validation_data=val_dataset)
benchmark_gbt_model.fit(train_dataset, validation_data=val_dataset)

看看结果怎么样:使用better_default参数,在ROC和PR auc中得到轻微的提升。benchmark_rank参数的性能要差得多。这就是为什么在部署结果模型之前正确地评估它们是很重要的。

GBT with 'Better Default' Parameters
PR AUC: 0.8483
ROC AUC: 0.9593
GBT with 'Benchmark Rank 1' Parameters
PR AUC: 0.7869
ROC AUC: 0.9442

2、定义搜索空间

TFDF附带了一个很好的程序,叫做RandomSearch,它在许多可用参数之间执行随机网格搜索。TFDF可以使用预定义的搜索空间或者通过一个选项可以手动指定这些参数(见示例)。如果您不太熟悉ML,这可能是一个很好的选择,因为它不需要手动设置这些参数。

# Create a Random Search tuner with 50 trials and automatic hp configuration.
tuner = tfdf.tuner.RandomSearch(num_trials=50, use_predefined_hps=True)
# Define and train the model.
tuned_model = tfdf.keras.GradientBoostedTreesModel(tuner=tuner)
tuned_model.fit(train_dataset, validation_data=val_dataset, verbose=2)

注意:无论那个超参数搜索都会耗费大量的时间,所以请根据模型谨慎选择。

进行完搜索后,可以使用下面的命令查看所有尝试过的组合。

tuning_logs = tuned_model.make_inspector().tuning_logs()

我们进行了12次迭代,最佳模型的表现比基线稍差,所以请谨慎使用内置的调优方法,建议使用其他库,比如Optuna。

PR AUC: 0.8216
ROC AUC: 0.9418

3、optuna

下面我们介绍如何使用optuna进行调优。

import optuna
def objective(trial: optuna.Trial) -> float:
params = {
"max_depth": trial.suggest_int("max_depth", 2, 10),
"l1_regularization": trial.suggest_float("l1_regularization", 0.01, 20),
"l2_regularization": trial.suggest_float("l2_regularization", 0.01, 20),
"growing_strategy": trial.suggest_categorical(
"growing_strategy", ["LOCAL", "BEST_FIRST_GLOBAL"]
),
"loss": trial.suggest_categorical(
"loss", ["BINOMIAL_LOG_LIKELIHOOD", "BINARY_FOCAL_LOSS"]
),
"min_examples": trial.suggest_int("min_examples", 5, 1000, step=5),
"focal_loss_alpha": trial.suggest_float("focal_loss_alpha", 0.05, 0.6),
"num_candidate_attributes_ratio": trial.suggest_float(
"num_candidate_attributes_ratio", 0.05, 0.95
),
"shrinkage": trial.suggest_float("shrinkage", 0.01, 0.9),
"early_stopping_num_trees_look_ahead": 50,
"num_trees": 2000,
}
model = tfdf.keras.GradientBoostedTreesModel(**params)
model.fit(train_dataset, validation_data=val_dataset, verbose=0)
preds = model.predict(val_dataset).ravel()
ap = average_precision_score(val_data[TARGET], preds)
return ap
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)

这些参数中的大多数对于gbt来说是相当标准的,但也有一些值得注意的参数:

  • 将growing_strategy更改为BEST_FIRST_GLOBAL(又称为按叶生长),这是LightGBM使用的策略。
  • 使用BINARY_FOCAL_LOSS,它应该对不平衡的数据集更好。
  • 更改split_axis参数以使用 sparse oblique splits,这在论文中证明是非常有效的。
  • 使用honest参数构建“honest trees”。

可以看到使用最佳参数获得的结果,自定义永远要比自动的好。

GBT with Custom Tuned Parameters
PR AUC: 0.8666
ROC AUC: 0.9631

现在我们已经确定了超参数,可以重新训练模型并继续进行我们的工作。

模型检验

TFDF提供了一个很好的实用工具来检查经过训练的模型,称为Inspector,他有3个主要用途:

  • 检查模型的属性,如类型,树的数量或使用的特征
  • 获取特性的重要性
  • 提取树结构

1、检查模型属性

inspector类存储了模型各种属性:模型类型(GBT或RF)、树的数量、训练目标以及用于训练模型的特征等等

inspector = manual_tuned.make_inspector()
print("Model type:", inspector.model_type())
print("Number of trees:", inspector.num_trees())
print("Objective:", inspector.objective())
print("Input features:", inspector.features())

或者使用manual_tuned.summary()来更详细地查看模型。

2、特征的重要性

像所有其他库一样,TFDF带有内置的特性重要性评分。对于gbt,可以访问NUM_NODES, SUM_SCORE, INV_MEAN_MIN_DEPTH, NUM_AS_ROOT方法。需要注意的是,可以在训练期间将compute_permutation_variable_importance参数设置为True,这将添加一些额外的方法,但是模型的训练速度会慢。

def plot_tfdf_importances(
inspector: tfdf.inspector.AbstractInspector, importance_type: str
):
"""Extracts and plots TFDF importances from the given inspector object
Args:
inspector (tfdf.inspector.AbstractInspector): inspector object created from your TFDF model
importance_type (str): importance type to plotß
"""
try:
importances = inspector.variable_importances()[importance_type]
except KeyError:
raise ValueError(
f"No {importance_type} importances found in the given inspector object"
)
names = []
scores = []
for f in importances:
names.append(f[0].name)
scores.append(f[1])
sns.barplot(x=scores, y=names, color="#5a7dbf")
plt.xlabel(importance_type)
plt.title("Variable Importance")
plt.show()

可以看到,Term变量一直是最重要的特征,紧随其后的是Bank、State和Bank State等类别变量。但是TFDF库最大的缺点之一是不能使用SHAP,这样可解释性的查看就有一些不方便。

3、检查树的结构

为了解释或模型验证,我们希望查看单个树。TFDF可以方便地访问所有树。比如GBT模型的第一棵树,因为它通常是信息量最大的树。

first_tree = inspector.extract_tree(tree_idx=0)
print(first_tree.pretty())

当处理较大的树时,使用print语句检查它们不太方便。TFDF提供了一个树绘图工具——TFDF .model_plotter

with open("plot.html", "w") as f:
f.write(tfdf.model_plotter.plot_model(manual_tuned, tree_idx=0, max_depth=4))

这样就方便很多了。

TF Serving

我们已经对模型进行了训练、调优和评估。最后的一个工作就是部署了,这部分也很简单,因为TFDF是官方的,必然会支持TF Serving。如果已经有了一个TF服务实例,那么所需要做的就是在model_base_path参数中指向要发布的模型。

首先就是保存我们的模型:

manual_tuned.save("../models/loan_default_model/1/")

然后就是在本地安装TF services,并使用正确的参数启动它。

./tensorflow_model_server \
--rest_api_port=8501 \
--model_name=loan_default_model \
--model_base_path=/path/models/loan_default_model/1

这里的model_base_path一定要是绝对路径。在TF服务服务器启动后,就可以开始接收请求了。有两种预期的格式——实例和输入。

# Input data formatted correctly
data = {
"Bank": ["Other"],
"BankState": ["TN"],
"City": ["Other"],
"CreateJob": [12.0],
"FranchiseCode": ["0"],
"GrAppv": [14900000.0],
"NoEmp": [28.0],
"RetainedJob": [16.0],
"RevLineCr": ["N"],
"SBA_Appv": [14900000.0],
"State": ["TN"],
"Term": [240.0],
"UrbanRural": ["0"],
"is_new": [0.0],
"latitude": [35.3468],
"longitude": [-86.22],
"naics_first_two": ["44"],
"same_state": [1.0],
"ApprovalFY": [1]
}
payload = {"inputs": data}
# Send the request
url = 'http://localhost:8501/v1/models/default_model:predict'
response = requests.post(url, json=payload)
# Print out the response
print(json.loads(response.text)['outputs'])
# Expected output: [[0.0138759678]]

返回的json串就包含了模型的预测结果。

总结

经过了2年的测试,TFDF终于发布正式版了。它是在TensorFlow中基于训练树的模型的一个强大且可扩展的库。TFDF模型与TensorFlow生态系统的其他部分很好地集成在一起,所以如果你正在使用TFX,在生产中有其他TF模型,或者正在使用TF服务,你会发现这个库非常有用。

本文的代码部分并不完全,如果你想自己探索,可以在这里下载完成的代码和数据集:

https://avoid.overfit.cn/post/db8eeaae665f410f90111b4ca0017b94

作者:Antons Tocilins-Ruberts

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
人民币对美元中间价报7.1192,在岸、离岸人民币汇率双双回升

人民币对美元中间价报7.1192,在岸、离岸人民币汇率双双回升

北京商报
2024-06-20 11:13:14
38岁已婚女与37岁情人,在石凳子上发生关系,温存后被残忍杀害

38岁已婚女与37岁情人,在石凳子上发生关系,温存后被残忍杀害

胖胖侃咖
2024-06-08 08:00:08
刚夺冠就被挖走,凯尔特人离队首人!5年3.15亿,塔图姆双喜临门

刚夺冠就被挖走,凯尔特人离队首人!5年3.15亿,塔图姆双喜临门

林子说事
2024-06-20 07:57:14
老美破防了!中企成功拿下高端光刻机,只为研发关键材料

老美破防了!中企成功拿下高端光刻机,只为研发关键材料

蚂蚁虾侃
2024-06-20 08:22:35
副董事长林斌被指违背承诺减持套现1.6亿 小米集团回应:是做公益

副董事长林斌被指违背承诺减持套现1.6亿 小米集团回应:是做公益

红星新闻
2024-06-20 12:26:10
浙江强降水落区有变!明起高温扩大,冲至36℃再猛降

浙江强降水落区有变!明起高温扩大,冲至36℃再猛降

鲁中晨报
2024-06-20 18:47:11
签约奖1.5亿欧!世体:巴萨与耐克续约谈判顺利 每年9000万固定奖

签约奖1.5亿欧!世体:巴萨与耐克续约谈判顺利 每年9000万固定奖

直播吧
2024-06-20 14:32:08
省委书记、省长为新机构揭牌!

省委书记、省长为新机构揭牌!

政知新媒体
2024-06-19 21:15:02
招聘(选聘)200余人!浙江一大批事业单位发布公告

招聘(选聘)200余人!浙江一大批事业单位发布公告

浙江发布
2024-06-20 15:00:09
难以相信!日本的国际形象很差吗?网友:佩服得五体投地!

难以相信!日本的国际形象很差吗?网友:佩服得五体投地!

有趣的羊驼
2024-06-06 14:22:52
桂林一药店女职员疑在洪水中触电身亡,附近商户:事后有居民提醒关闸

桂林一药店女职员疑在洪水中触电身亡,附近商户:事后有居民提醒关闸

鲁中晨报
2024-06-20 16:59:09
犹如高高在上的公主,让人愿意臣服。

犹如高高在上的公主,让人愿意臣服。

阿芒娱乐说
2024-05-29 16:23:53
房地产之后,中国即将出现2个新的造富机会:未来20年都是趋势!

房地产之后,中国即将出现2个新的造富机会:未来20年都是趋势!

森林聊商业
2023-06-21 10:10:18
上海白马会所:令无数富婆欲罢不能,19年因“头牌鸭王”一夜覆灭

上海白马会所:令无数富婆欲罢不能,19年因“头牌鸭王”一夜覆灭

V盟文史
2023-06-13 19:38:51
有那么一瞬间,我觉得大s的闹一点也不过分

有那么一瞬间,我觉得大s的闹一点也不过分

乐观探历史
2024-06-16 21:42:37
1726年,年羹尧得知自己将要被问斩,把怀孕小妾送给一秀才

1726年,年羹尧得知自己将要被问斩,把怀孕小妾送给一秀才

百态人间
2024-06-19 16:31:56
农行董事长谷澍:上海到店客户超40%是60岁以上人士,未来可能会进一步提升!需加强适老化改造

农行董事长谷澍:上海到店客户超40%是60岁以上人士,未来可能会进一步提升!需加强适老化改造

和讯网
2024-06-20 12:28:23
姜萍是否找人代考?达摩院澄清对姜萍很不利,浙大欲言又止

姜萍是否找人代考?达摩院澄清对姜萍很不利,浙大欲言又止

平老师666
2024-06-19 20:59:10
刘品言「难得关滤镜」不P图了!泳衣包不住半球 真实身材现形

刘品言「难得关滤镜」不P图了!泳衣包不住半球 真实身材现形

ETtoday星光云
2024-06-20 16:18:08
小S又曝大S劲爆旧料:去奢侈品购物被店员无视,一怒之下买了两个

小S又曝大S劲爆旧料:去奢侈品购物被店员无视,一怒之下买了两个

小徐讲八卦
2024-06-19 07:30:47
2024-06-21 02:30:44
deephub
deephub
CV NLP和数据挖掘知识
1373文章数 1416关注度
往期回顾 全部

科技要闻

小米SU7流量泼天,富贵却被蔚来接住了

头条要闻

媒体:以为中国会服软 菲在南海主权之争上存低级误判

头条要闻

媒体:以为中国会服软 菲在南海主权之争上存低级误判

体育要闻

千夫所指的关系户 成了拯救葡萄牙的英雄

娱乐要闻

叶舒华参加柯震东生日聚会,五毒俱全

财经要闻

楼市新“王炸”!释放何信号?

汽车要闻

售价11.79-14.39万元 新一代哈弗H6正式上市

态度原创

家居
艺术
旅游
数码
时尚

家居要闻

自然开放 实现灵动可变空间

艺术要闻

穿越时空的艺术:《马可·波罗》AI沉浸影片探索人类文明

旅游要闻

铁路儿童票新规 已有超4900万小旅客免费出行

数码要闻

AMD 发布 ROCm 6.1.3,支持 RX 7900 GRE 显卡及 TensorFlow

当男人不耍帅时,就是最帅的时候(穿衣篇)

无障碍浏览 进入关怀版