网易首页 > 网易号 > 正文 申请入驻

混合密度网络(MDN)进行多元回归详解和代码示例

0
分享至

在本文中,首先简要解释一下 混合密度网络 MDN (Mixture Density Network)是什么,然后将使用Python 代码构建 MDN 模型,最后使用构建好的模型进行多元回归并测试效果。

回归

“回归预测建模是逼近从输入变量 (X) 到连续输出变量 (y) 的映射函数 (f) [...] 回归问题需要预测具体的数值。 具有多个输入变量的问题通常被称为多元回归问题 例如,预测房屋价值,可能在 100,000 美元到 200,000 美元之间

这是另一个区分分类问题和回归问题的视觉解释如下:

另外一个例子

密度

DENSITY “密度” 是什么意思? 这是一个快速的通俗示例:

假设正在为必胜客运送比萨。 现在记录刚刚进行的每次交付的时间(以分钟为单位)。 交付 1000 次后,将数据可视化以查看工作表现如何。 这是结果:

这是披萨交付时间数据分布的“密度”。平均而言,每次交付需要 30 分钟(图中的峰值)。 它还表示,在 95% 的情况下(2 个标准差2sd ),交付需要 20 到 40 分钟才能完成。 密度种类代表时间结果的“频率”。 “频率”和“密度”的区别在于:

· 频率:如果你在这条曲线下绘制一个直方图并对所有的 bin 进行计数,它将求和为任何整数(取决于数据集中捕获的观察总数)。

· 密度:如果你在这条曲线下绘制一个直方图并计算所有的 bin,它总和为 1。我们也可以将此曲线称为概率密度函数 (pdf)。

用统计术语来说,这是一个漂亮的正态/高斯分布。 这个正态分布有两个参数:

均值

· 标准差:“标准差是一个数字,用于说明一组测量值如何从平均值(平均值)或预期值中展开。低标准偏差意味着大多数数字接近平均值。高标准差意味着数字更加分散。“

均值和标准差的变化会影响分布的形状。 例如:

有许多具有不同类型参数的各种不同分布类型。 例如:

混合密度

现在让我们看看这 3 个分布:

如果我们采用这种双峰分布(也称为一般分布):

网络架构

混合密度网络使用这样的假设,即任何像这种双峰分布的一般分布都可以分解为正态分布的混合(该混合也可以与其他类型的分布一起定制 例如拉普拉斯):

混合密度网络也是一种人工神经网络。 这是神经网络的经典示例:

输入层(黄色)、隐藏层(绿色)和输出层(红色)。

如果我们将神经网络的目标定义为学习在给定一些输入特征的情况下输出连续值。 在上面的例子中,给定年龄、性别、教育程度和其他特征,那么神经网络就可以进行回归的运算。

密度网络

密度网络也是神经网络,其目标不是简单地学习输出单个连续值,而是学习在给定一些输入特征的情况下输出分布参数(此处为均值和标准差)。 在上面的例子中,给定年龄、性别、教育程度等特征,神经网络学习预测期望工资分布的均值和标准差。预测分布比预测单个值具有很多的优势,例如能够给出预测的不确定性边界。 这是解决回归问题的“贝叶斯”方法。下面是预测每个预期连续值的分布的一个很好的例子:

下面的图片向我们展示了每个预测实例的预期值分布:

混合密度网络

最后回到正题,混合密度网络的目标是在给定特定输入特征的情况下,学习输出混合在一般分布中的所有分布的参数(此处为均值、标准差和 Pi)。 新参数“Pi”是混合参数,它给出最终混合中给定分布的权重/概率。

最终结果如下:

示例1:单变量数据的 MDN 类

上面的定义和理论基础已经介绍完毕,下面我们开始代码的演示:

import numpy as np
import pandas as pd
from mdn_model import MDN
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LinearRegression
from sklearn.kernel_ridge import KernelRidge
plt.style.use('ggplot')

生成著名的“半月”型的数据集:

X, y = make_moons(n_samples=2500, noise=0.03)
y = X[:, 1].reshape(-1,1)
X = X[:, 0].reshape(-1,1)
x_scaler = StandardScaler()
y_scaler = StandardScaler()
X = x_scaler.fit_transform(X)
y = y_scaler.fit_transform(y)
plt.scatter(X, y, alpha = 0.3)

绘制目标值 (y) 的密度分布:

sns.kdeplot(y.ravel(), shade=True)

通过查看数据,我们可以看到有两个重叠的簇:

这时一个很好的多模态分布(一般分布)。 如果我们在这个数据集上尝试一个标准的线性回归来用 X 预测 y:

model = LinearRegression()
model.fit(X.reshape(-1,1), y.reshape(-1,1))
y_pred = model.predict(X.reshape(-1,1))
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Linear Regression')

sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'Linear Pred dist') sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

效果必须不好!现在让尝试一个非线性模型(径向基函数核岭回归):

model = KernelRidge(kernel = 'rbf')
model.fit(X, y)
y_pred = model.predict(X)
plt.scatter(X, y, alpha = 0.3)
plt.scatter(X,y_pred)
plt.title('Non Linear Regression')

sns.kdeplot(y_pred.ravel(), shade=True, alpha = 0.15, label = 'NonLinear Pred dist') sns.kdeplot(y.ravel(), shade=True, label = 'True dist')

虽然结果也不尽如人意,但是比上面的线性回归要好很多了。

两个模型都没有成功的主要原因是:对于同一个 X 值存在多个不同的 y 值……更具体地说,对于同一个 X 似乎存在不止一个可能的 y 分布。 回归模型只是试图找到最小化误差的最优函数,并没有考虑到密度的混合,所以 中间的那些X没有唯一的Y解,它们有两种可能的解,所以导致了以上的问题。

现在让我们尝试一个 MDN 模型,这里已经实现了一个快速且易于使用的“fit-predict”、“sklearn alike”自定义 python MDN 类。 如果您想自己使用它,本文最后会附带 python 代码的链接(请注意:这个 MDN 类是实验性的,尚未经过广泛测试)

为了能够使用这个类,有 sklearn、tensorflow probability、Tensorflow < 2、umap 和 hdbscan(用于自定义可视化类 功能)。

EPOCHS = 10000
BATCH_SIZE=len(X)
model = MDN(n_mixtures = -1,
dist = 'laplace',
input_neurons = 1000,
hidden_neurons = [25],
gmm_boost = False,
optimizer = 'adam',
learning_rate = 0.001,
early_stopping = 250,
tf_mixture_family = True,
input_activation = 'relu',
hidden_activation = 'leaky_relu')
model.fit(X, y, epochs = EPOCHS, batch_size = BATCH_SIZE)

类的参数总结如下:

· n_mixtures:MDN 使用的分布混合数。如果设置为 -1,它将使用高斯混合模型 (GMM) 和 X 和 y 上的 HDBSCAN 模型“自动”找到最佳混合数。

· dist:在混合中使用的分布类型。目前,有两种选择; “正常”或“拉普拉斯”。 (基于一些实验,拉普拉斯分布比正态分布更好的结果)。

· input_neurons:在MDN的输入层中使用的神经元数量

· hidden_neurons:MDN的 隐藏层架构。每个隐藏层的神经元列表。此参数使您能够选择隐藏层的数量和每个隐藏层的神经元数量。

· gmm_boost:布尔值。如果设置为 True,将向数据集添加簇特征。

· optimizer:要使用的优化算法。

· learning_rate:优化算法的学习率

· early_stopping:避免训练时过拟合。当指标在给定数量的时期内没有变化时,此触发器将决定何时停止训练。

· tf_mixture_family:布尔值。如果设置为 True,将使用 tf_mixture 系列(推荐):Mixture 对象实现批量混合分布。

· input_activation:输入层的激活函数

· hidden_activation:隐藏层的激活函数

现在 MDN 模型已经拟合了数据,从混合密度分布中采样并绘制概率密度函数:

model.plot_distribution_fit(n_samples_batch = 1)

我们的 MDN 模型非常适合真正的一般分布!下面将最终的混合分布分解为每个分布,看看它的样子:

model.plot_all_distribution_fit(n_samples_batch = 1)

使用学习到的混合分布再次采样一些 Y 数据,生成的样本与真实样本进行对比:

model.plot_samples_vs_true(X, y, alpha = 0.2)

与实际的数据非常接近,如果,给定 X还可以生成多批样本以生成分位数、均值等统计信息:

generated_samples = model.sample_from_mixture(X, n_samples_batch = 10) generated_samples

绘制每个学习分布的平均值,以及它们各自的混合权重 (pi):

plt.scatter(X, y, alpha = 0.2) model.plot_predict_dist(X, with_weights = True, size = 250)

有每个分布的均值和标准差,还可以绘制带有完整不确定性; 假设我们以 95% 的置信区间绘制平均值:

plt.scatter(X, y, alpha = 0.2) model.plot_predict_dist(X, q = 0.95, with_weights = False)

将分布混合在一起,当对同一个 X 有多个 y 分布时,我们使用最高 Pi 参数值选择最可能的混合:

Y_preds = 对于每个 X,选择具有最大概率/权重(Pi 参数)的分布的 Y 均值

plt.scatter(X, y, alpha = 0.3) model.plot_predict_best(X)

这种方式表现得并不理想,因为在数据中显然有两个不同的簇重叠,密度几乎相等。 使得误差将高于标准回归模型。 这也意味着数据集中可能缺少一个可以帮助避免集群在更高维度上重叠重要特征。

我们还可以选择使用 Pi 参数和所有分布的均值混合分布:

· Y_preds = (mean_1 * Pi1) + (mean_2 * Pi2)

plt.scatter(X, y, alpha = 0.3) model.plot_predict_mixed(X)

如果我们添加 95 置信区间:

这个选项提供了与非线性回归模型几乎相同的结果,混合所有内容以最小化点和函数之间的距离。 在这个非常特殊的情况下,我最喜欢的选择是假设在数据的某些区域,X 有多个 Y,而在其他区域; 仅使用其中一种混合。:

例如,当 X = 0 时,每种混合可能有两种不同的 Y 解。 当 X = -1.5 时,混合 1 中存在唯一的 Y 解决方案。根据用例或业务上下文,当同一个 X 存在多个解决方案时,可以触发操作或决策。

这个选项得含义是当存在重叠分布时(如果两个混合概率都 >= 给定概率阈值),行将被复制:

plt.scatter(X, y, alpha = 0.3) model.plot_predict_with_overlaps(X)

使用 95% 置信区间:

数据集行从 2500 增加到了 4063,最终预测数据集如下所示:

在这个数据表中,当 X = -0.276839 时,Y 可以是 1.43926(混合_0 的概率为 0.351525),但也可以是 -0.840593(混合_1 的概率为 0.648475)。

具有多个分布的实例还提供了重要信息,即数据中正在发生某些事情,并且可能需要更多分析。可能是一些数据质量问题,或者可能表明数据集中缺少一个重要特征!

“交通场景预测是可以使用混合密度网络的一个很好的例子。在交通场景预测中,我们需要一个可以表现出的行为分布——例如,一个代理可以左转、右转或直行。因此,混合密度网络可用于表示它学习的每个混合中的“行为”,其中行为由概率和轨迹组成((x,y)坐标在未来某个时间范围内)。

示例2:具有MDN 的多变量回归

最后MDN 在多元回归问题上表现良好吗?

我们将使用以下的数据集:

· 年龄:主要受益人的年龄

· 性别:保险承包商性别,女,男

· bmi:体重指数,提供对身体的了解,相对于身高相对较高或较低的体重,使用身高与体重之比的体重客观指数(kg / m ^ 2),理想情况下为18.5到24.9

· 子女:健康保险覆盖的子女人数/受抚养人人数

· 吸烟者:吸烟

· 地区:受益人在美国、东北、东南、西南、西北的居住区。

· 费用:由健康保险计费的个人医疗费用。 这是我们要预测的目标

问题陈述是:能否准确预测保险费用(收费)?

现在,让我们导入数据集,训练完成后使用“最佳混合概率(Pi 参数)策略”预测测试数据集并绘制结果(y_pred vs y_test):

y_pred = model.predict_best(X_test, q = 0.95, y_scaler = y_scaler) model.plot_pred_fit(y_pred, y_test, y_scaler = y_scaler)

model.plot_pred_vs_true(y_pred, y_test, y_scaler = y_scaler)

R2 为 89.09,MAE 为 882.54,MDN太棒了,让我们绘制拟合分布与真实分布的图来进行对比:

model.plot_distribution_fit(n_samples_batch = 1)

几乎一模一样!分解混合模型,看看什么情况:

一共混合了六种不同的分布。

从拟合的混合模型生成多变量样本(应用 PCA 以在 2D 中可视化结果):

model.plot_samples_vs_true(X_test, y_test, alpha = 0.35, y_scaler = y_scaler)

生成的样本与真实样本非常接近!如果我们愿意,还可以从每个分布中进行预测:

y_pred_dist = model.predict_dist(X_test, q = 0.95, y_scaler = y_scaler) y_pred_dist

总结

· 与线性或非线性经典 ML 模型相比,MDN 在单变量回归数据集中表现出色,其中两个簇相互重叠,并且 X 可能有多个 Y 输出。

· MDN 在多元回归问题上也做得很好,可以与 XGBoost 等流行模型竞争

· MDN 是 ML 中的一款出色且独特的工具,可以解决其他模型无法解决的特定问题(能够从混合分布中获得的数据中学习)

· 随着 MDN 学习分布,还可以通过预测计算不确定性或从学习的分布中生成新样本

本文的代码非常的多,这里是完整的notebook,可以直接下载运行:

https://www.overfit.cn/post/20245a8446ae43e3982b48e4320991ab

作者:Dave Cote, M.Sc.

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
太突然!知名品牌宣布关闭在中国所有线上线下店铺,店员:正2折起清仓!入华近20年,1月刚从纽交所私有化退市,网友:有点可惜

太突然!知名品牌宣布关闭在中国所有线上线下店铺,店员:正2折起清仓!入华近20年,1月刚从纽交所私有化退市,网友:有点可惜

每日经济新闻
2026-02-28 14:51:10
美以袭击致伊朗201人死747人伤

美以袭击致伊朗201人死747人伤

财联社
2026-03-01 01:44:07
内塔尼亚胡:诸多迹象显示哈梅内伊已“不在人世”

内塔尼亚胡:诸多迹象显示哈梅内伊已“不在人世”

财联社
2026-03-01 03:39:04
伊朗别慌,三招反击美以联军,实用管用

伊朗别慌,三招反击美以联军,实用管用

今日马说
2026-02-28 20:01:39
马斯克藏太深!美星链离不开7家中国公司,每一家都是全球顶尖!

马斯克藏太深!美星链离不开7家中国公司,每一家都是全球顶尖!

爱吃醋的猫咪
2026-02-27 17:56:07
后续!绝情臭豆腐最新进展:负责人正脸曝光社死,店老板公开道歉

后续!绝情臭豆腐最新进展:负责人正脸曝光社死,店老板公开道歉

离离言几许
2026-02-28 18:08:15
死了白死?俄方绝口不提贝加尔湖惨剧赔偿 中日美此类案例都有赔款

死了白死?俄方绝口不提贝加尔湖惨剧赔偿 中日美此类案例都有赔款

劲爆体坛
2026-02-28 18:30:13
美伊冲突引市场巨震,油价或历史性飙升,未来48小时决定金价走势,比特币近15万人爆仓

美伊冲突引市场巨震,油价或历史性飙升,未来48小时决定金价走势,比特币近15万人爆仓

21世纪经济报道
2026-02-28 23:56:17
伊朗伊通社网站恢复正常运行

伊朗伊通社网站恢复正常运行

环球网资讯
2026-02-28 16:07:07
美国为什么不敢打伊朗?专家的预测又被打脸了

美国为什么不敢打伊朗?专家的预测又被打脸了

历史总在押韵
2026-02-28 23:31:28
四强又只剩王楚钦了!7人相继被淘汰,林诗栋引失望,陈垣宇惊喜

四强又只剩王楚钦了!7人相继被淘汰,林诗栋引失望,陈垣宇惊喜

篮球资讯达人
2026-02-28 22:31:49
知名演员秦岚自曝患病,已做手术!

知名演员秦岚自曝患病,已做手术!

极目新闻
2026-02-28 23:12:57
特朗普突然发文昭告全球,包括中国俄罗斯在内,这次一个都跑不掉

特朗普突然发文昭告全球,包括中国俄罗斯在内,这次一个都跑不掉

带你领略快乐真谛
2026-02-28 16:55:50
“大力神”军机坠毁 已致15人死亡 天空下钞票雨 民众疯抢!

“大力神”军机坠毁 已致15人死亡 天空下钞票雨 民众疯抢!

每日经济新闻
2026-02-28 14:37:58
多数珠宝品牌商现已不回收白银,部分周大福门店还表示目前黄金回收也已暂停

多数珠宝品牌商现已不回收白银,部分周大福门店还表示目前黄金回收也已暂停

黄河新闻网吕梁
2026-02-28 09:13:37
女子回湖北婆家过年,车被妯娌砸稀烂,报警后绝不和解,结局爽了

女子回湖北婆家过年,车被妯娌砸稀烂,报警后绝不和解,结局爽了

不写散文诗
2026-02-28 17:19:21
“重大作战”,要打多久?

“重大作战”,要打多久?

中国新闻周刊
2026-02-28 20:19:57
伊朗第七轮导弹射向以色列

伊朗第七轮导弹射向以色列

界面新闻
2026-02-28 20:30:44
扛不住了,江苏某大型建设集团全员息岗!

扛不住了,江苏某大型建设集团全员息岗!

黯泉
2026-02-28 22:39:41
金融圈突发!涉嫌严重违纪违法,金春花被查

金融圈突发!涉嫌严重违纪违法,金春花被查

中国基金报
2026-02-28 17:17:02
2026-03-01 05:55:00
deephub incentive-icons
deephub
CV NLP和数据挖掘知识
1934文章数 1456关注度
往期回顾 全部

科技要闻

狂揽1100亿美元!OpenAI再创融资神话

头条要闻

以官员称哈梅内伊身亡 遗体在其官邸废墟中被找到

头条要闻

以官员称哈梅内伊身亡 遗体在其官邸废墟中被找到

体育要闻

球队主力全报销?顶风摆烂演都不演了

娱乐要闻

周杰伦儿子正面照曝光,与父亲好像

财经要闻

冲突爆发 市场变天?

汽车要闻

岚图泰山黑武士版3月上市 搭载华为四激光智驾方案

态度原创

手机
游戏
亲子
公开课
军事航空

手机要闻

澎湃OS再次公布进展通报:10个问题,仅修复一则!

所有人保持嘴角不变!生化危机:安魂曲里昂骚话大盘点

亲子要闻

婴儿吃的“洋”辅食,九成靠代工贴牌?涉及上市公司

公开课

李玫瑾:为什么性格比能力更重要?

军事要闻

美国以色列联合袭击伊朗 实时战况

无障碍浏览 进入关怀版