网易首页 > 网易号 > 正文 申请入驻

在表格数据集上训练变分自编码器 (VAE)示例

0
分享至

变分自编码器 (VAE) 是在图像数据应用中被提出,但是VAE不仅可以应用在图像中。在这篇文章中,我们将简单介绍什么是VAE,以及解释“为什么”变分自编码器是可以应用在数值类型的数据上,最后使用Numerai数据集展示“如何”训练它。

Numerai数据集数据集包含全球股市数十年的历史数据,在Numerai的锦标赛中,使用这个数据集来进行股票的投资收益预测和加密币NMR的收益预测。

一般来说 VAE 可以进行异常检测、去噪和生成合成数据。

异常检测

异常检测可以关于识别显着偏离大多数数据和不符合明确定义的正常行为概念的样本。 在 Numerai 数据集中这些异常可能是存在财务异常时期,检测到这些时期会为我们的预测提供额外的信息。

去噪

去噪是从信号中去除噪声的过程。 我们可以应用 VAE 对大多数偏离的特征进行降噪。 去噪转换噪声特征,一般情况下我们会将异常检测出的样本标记为噪声样本。

生成合成数据

使用 VAE,我们可以从正态分布中采样并将其传递给解码器以获得新的样本。

哪为什么选择变分自编码器呢?

自编码器由两个主要部分组成:

1)将输入映射为潜在空间的编码器

2)使用潜在空间重构输入的解码器

潜在空间在原论文中也被称为表示变量或潜在变量。那么为什么称为变分呢?将潜在表示的分布强制转换到一个已知的分布(如高斯分布),因为典型的自编码器不能控制潜在空间的分布而(VAE)提供了一种概率的方式来描述潜在空间中的观察。因此我们构建的编码器不是输出单个值来描述每个潜在空间的属性,而是用编码器来描述每个潜在属性的概率分布。在本文中我们使用了最原始的VAE,我们称之为vanilla VAE(以下称为原始VAE)

编码器由一个或多个全连接的层组成,其中最后一层输出正态分布的均值和方差。均值和方差值用于从相应的正态分布中采样,采样将作为输入到解码器。解码器由也是由一个或多个完全连接的层组成,并输出编码器输入的重建版本。下图展示了VAE的架构:

与普通自动编码器不同,VAE编码器模型将输出潜伏空间中每个维度的分布特征参数,而不是潜在空间的值。编码器将输出两个向量,反映潜在状态分布的均值和方差,因为我们假设先验具有正态分布。 然后,解码器模型将通过从这些定义的分布中采样来构建一个潜在向量,之后它将为解码器的输入重建原始输入。

普通 VAE 的损失函数中有两个项:1)重建误差和 2)KL 散度:

普通 VAE 中使用的重建误差是均方误差 (MSE)。 MSE 损失试图使重构的信号与输入信号相似性。 KL 散度损失试图使代码的分布接近正态分布。 q(z|x) 是给定输入信号的代码分布,p(z) 是正态分布。 PyTorch 代码如下所示:

recons_loss = F.mse_loss(recons, input)
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

原始VAE 配置如下所示:

model_params:
name: 'NumeraiHistogram of KL divergence (left) and mean-squared reconstruction lossVAE'
in_channels: 1191
latent_dim: 32
data_params:
data_path: "/train.parquet"
train_batch_size: 4096
val_batch_size: 4096
num_workers: 8
exp_params:
LR: 0.005
weight_decay: 0.0
scheduler_gamma: 0.95
kld_weight: 0.00025
manual_seed: 1265
trainer_params:
gpus: [1]
max_epochs: 300
logging_params:
save_dir: "logs/"
name: "NumeraiVAE"

配置中的关键参数有:

in_channels:输入特征的数量

latent_dim:VAE 的潜在维度。

编码器/解码器包括线性层,然后是批量归一化和leakyReLU 激活。

编码器的模型定义:

# Build Encoder
modules = []
modules.append(
nn.Sequential(
nn.Linear(in_channels, latent_dim),
nn.BatchNorm1d(latent_dim),
nn.LeakyReLU(),
))
self.encoder = nn.Sequential(*modules)
self.fc_mu = nn.Linear(latent_dim, latent_dim)
self.fc_var = nn.Linear(latent_dim, latent_dim)

解码器的模型定义:

# Build Decoder
modules = []
self.decoder_input = nn.Linear(latent_dim, latent_dim)
modules.append(
nn.Sequential(
nn.Linear(latent_dim, in_channels),
nn.BatchNorm1d(in_channels),
nn.LeakyReLU()
))
self.decoder = nn.Sequential(*modules)

python3 run.py --config configs/numerai_vae.yaml

如果没有报错应该打印以下日志:

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
======= Training NumeraiVAE =======
Global seed set to 1265
initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]| Name | Type | Params
-------------------------------------
0 | model | NumeraiVAE | 83.1 K
-------------------------------------
83.1 K Trainable params
0 Non-trainable params
83.1 K Total params
0.332 Total estimated model params size (MB)
Global seed set to 1265
Epoch 19: 100%|██████████████████████████████████████████████████████████████████████████| 592/592 [00:20<00:00, 28.49it/s, loss=0.0818, v_num=3]

如何使用 VAE 进行异常检测?

异常是具有高损失值的样本。 损失值可以是重建损失、KL散度损失或它们的组合。

Numerai 训练数据集上的 KL 散度的直方图

这是MSE损失的直方图。

下图是Numerai 训练数据集的 KL 散度和均方误差的可视化。该图训练后的 VAE 的潜在维度为 2,因此我们可以将其可视化。

如何用 VAE 去噪?

首先将带有噪声的输入传递给编码器以获取潜在空间。 然后将潜在空间传递给解码器以获得去噪后输入(重建输入)。

如何使用 VAE 生成合成数据?

由于解码器的输入遵循已知分布(即高斯分布),我们可以从高斯分布中采样并将值传递给解码器就可以获得新的合成数据。

https://avoid.overfit.cn/post/144af920f43240be9ed07f0a8e0d6051

作者:Amir Erfan Eshratifar

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
广州当年的战略短视苦果,从放手东莞的那一刻就开始了!

广州当年的战略短视苦果,从放手东莞的那一刻就开始了!

元爸体育
2024-06-15 21:12:43
都是产能惹的祸?小米汽车卖爆了,雷军压力反而更大了

都是产能惹的祸?小米汽车卖爆了,雷军压力反而更大了

电车通
2024-06-14 21:40:28
路易小王子在庆典上表情包又火了!差点抢了凯特风头,还打哈欠呢

路易小王子在庆典上表情包又火了!差点抢了凯特风头,还打哈欠呢

童童聊娱乐啊
2024-06-15 23:14:36
3年1.64亿,续约湖人!钱我要,冠军我也要!你打破了全联盟的质疑

3年1.64亿,续约湖人!钱我要,冠军我也要!你打破了全联盟的质疑

开心体育站
2024-06-15 22:58:25
全球第12名 !中专女生姜萍决赛能拿奖吗?组委会:决赛难度博士级

全球第12名 !中专女生姜萍决赛能拿奖吗?组委会:决赛难度博士级

小李子体育
2024-06-15 19:15:00
放心,不会有古巴导弹危机2.0版了

放心,不会有古巴导弹危机2.0版了

报人刘亚东
2024-06-14 20:11:22
尘埃落定,瓜帅返回!不回老东家,曼城开启B计划,英超大结局

尘埃落定,瓜帅返回!不回老东家,曼城开启B计划,英超大结局

阿泰希特
2024-06-15 21:26:57
杨健:独行侠一场暴胜之后 更感到G3的策略失误是致命性的打击

杨健:独行侠一场暴胜之后 更感到G3的策略失误是致命性的打击

直播吧
2024-06-15 11:22:14
震撼!中国终于宣告收回被占领70年的领土,背后的故事让人震惊!

震撼!中国终于宣告收回被占领70年的领土,背后的故事让人震惊!

趣说世界哈
2024-06-13 11:25:11
告诉大家一个消息!2024年,手握2套房子以上的家庭或逃不掉

告诉大家一个消息!2024年,手握2套房子以上的家庭或逃不掉

山丘楼评
2024-05-26 11:42:59
惊呆!这里房价暴跌70%,真应了马云说的白菜价

惊呆!这里房价暴跌70%,真应了马云说的白菜价

山丘楼评
2024-06-14 23:55:48
上海失踪女童遗体已找到 排除刑事案件

上海失踪女童遗体已找到 排除刑事案件

北青网-北京青年报
2024-06-15 17:05:07
窦佳嫄从来没叫过窦唯爸爸,也没送过四合院,上一次见面是两年前

窦佳嫄从来没叫过窦唯爸爸,也没送过四合院,上一次见面是两年前

五四观娱
2024-06-10 22:41:43
余承东该紧张了,问界销量,落后理想越来越远了

余承东该紧张了,问界销量,落后理想越来越远了

互联网.乱侃秀
2024-06-14 09:56:25
陈小春父亲节晒全家福,41岁应采儿提前庆生,Jasper高冷弟弟好萌

陈小春父亲节晒全家福,41岁应采儿提前庆生,Jasper高冷弟弟好萌

娱絮
2024-06-16 03:02:09
风波升级!央视六公主深夜再放证据,打脸粉丝,周也踢到铁板了!

风波升级!央视六公主深夜再放证据,打脸粉丝,周也踢到铁板了!

王小乖
2024-06-15 16:08:03
买房人开始崩溃了!

买房人开始崩溃了!

山丘楼评
2024-06-13 23:05:40
老百姓傻眼了?全国铁饭碗人数加起来都不到一个亿!

老百姓傻眼了?全国铁饭碗人数加起来都不到一个亿!

娱乐洞察点点
2024-06-15 16:37:29
汪小菲已在台北给马筱梅买房,大S一改从前,换了一种方式反击!

汪小菲已在台北给马筱梅买房,大S一改从前,换了一种方式反击!

鑫鑫说说
2024-06-14 13:47:35
6月15日俄乌:92国出席瑞士和平峰会,G7有收获,普京提和谈条件

6月15日俄乌:92国出席瑞士和平峰会,G7有收获,普京提和谈条件

山河路口
2024-06-15 14:28:58
2024-06-16 08:54:44
deephub
deephub
CV NLP和数据挖掘知识
1368文章数 1416关注度
往期回顾 全部

科技要闻

TikTok开始找退路了?

头条要闻

牛弹琴:梅洛尼和马克龙吵了一架 晚宴上眼神可"杀人"

头条要闻

牛弹琴:梅洛尼和马克龙吵了一架 晚宴上眼神可"杀人"

体育要闻

莱夫利,让困难为我让路

娱乐要闻

江宏杰秀儿女刺青,不怕刺激福原爱?

财经要闻

新情况!高层对人民币的态度180°转弯

汽车要闻

东风奕派eπ008售21.66万元 冰箱彩电都配齐

态度原创

时尚
家居
数码
健康
公开课

可以轻松借鉴的通勤装扮,女人多穿“过膝裙”,优雅时尚大气

家居要闻

空谷来音 朴素留白的侘寂之美

数码要闻

OPPO 新款 100W SUPERVOOC 快充移动电源通过 3C 认证

晚餐不吃or吃七分饱,哪种更减肥?

公开课

近视只是视力差?小心并发症

无障碍浏览 进入关怀版