来给大家介绍一下我们被接收为ICLR Spotlight的新工作。这个work从2021年春开始一直做到2022年秋,中间克服了许多技术障碍,没想到第一次投稿就好评如潮(分数8886),也恭喜子昊的坚持得到回报。这篇工作的核心贡献在于,正式定义了domain adaptation中的域索引(domain index),精心设计了推断(infer)domain index的算法(variational domain indexing, 即VDI),并且证明了我们的算法可以推断出最优的domain index。由于推断出来的domain index带来的free lunch,domain adaptation的性能也得到了提高。
论文链接: http://wanghao.in/paper/ICLR23_VDI.pdf 代码链接: https://github.com/Wang-ML-Lab/VDI
什么是domain index:domain index的说法最早在我们的ICML 2020论文“Continuously Indexed Domain Adaptation”(CIDA)中提出(有兴趣的看官欢迎移步我们讲CIDA的知乎帖子)。最直观的例子就是在医疗应用里面,不同年龄的人可以看成是不同的domain,而这个“年龄”其实就是domain的一个索引(index),也就是我们说的domain index(域索引)。如下图。
有意思的是,domain index其实是一个连续的概念,所以自然而然地包含了domain的远近信息。比如上面说的“年龄”可以作为一个一维的domain index,年龄18和19距离很近,而18和80却距离很远。我们之前在CIDA(大致的CIDA模型如下图)上的实验发现,如果已知这个domain index,我们可以很好地做到连续域上的domain adaptation,从而大幅提高准确率。
比如把模型从年龄0 ~ 20的病人(source domains),adapt到年龄20 ~ 80的病人(target domains),或者从年龄0 ~ 15以及50 ~ 80的病人(source domains),adapt到年龄15 ~ 50的病人(target domains),如下图。
那么问题来了,如果这个domain index是未知的,咋办? 最理想的情况当然是,我们能够把这个domain index作为隐变量(latent variable),通过 无监督 (unsupervised)的方式把它推断(infer)出来。如果这个方案可行,我们就免费拿到了一个重要的额外的信息,从而既可以提高domain adaptation的 准确率 ,又能提高它的 可解释性。
Domain Index的正式定义 : 在推断domain index前,我们要先定义清楚,什么才算是domain index,然后才能设计推断它的方法。
这里我们首先引入了两种domain index, local domain index(用u表示)和global domain index(用β表示)。 我们规定,虽然同一个domain里的不同数据点(data point)可以有不同的local domain index,但是同一个domain里的所有数据的global domain index必须是是 相同的 。也就是说,local domain index是一个instance-level的变量,而global domain index是一个domain-level的变量。下面的图是一个具体的例子,展示了global domain index β、local domain index u、数据x之间的关系。
那么符合什么条件的u和beta才能被叫做domain index呢?我们定义了三个条件(这里x表示数据,y表示标签,z表示x经过encoder后得到的encoding):
z和β的条件独立:Encoding z和global domain index β是条件独立的。换句话说,他们的互信息 β 必须是0。
保留x的信息:Encoding z,global domain index β,和local domain index u这三组变量 比如尽可能地保留数据x的信息。换句话说,他们的互信息 β 必须达到最大
z对标签y的敏感度:Encoding z要尽可能保留标签y的信息(这样才能提高预测y的准确率)。这意味着他们的互信息 必须达到最大 。
如果 β 和 u 满足上述三个条件,我们就把它们分别称为global domain index和local domain index。这三个条件可以用下面的数学公式表示:
方法的整体思路 : 定义完domain index后,下一个问题自然就是,如何能在无监督(完全不知道domain index)的情况下,有效地推断出符合上面三段定义的domain index β 和 u 呢?这时,就要请出adversarial Bayesian deep learning model(对Bayesian deep learning感兴趣的同学可以看看我们之前的帖子)来解决这个问题。
在Bayesian deep learning里面,或者更加传统的probabilistic graphical model里面,我们会分两步走:
第一步 是首先假设一下 已知变量(observed variable) 是如何从 隐变量(latent variable,即未知的变量) 一步步生成的。我们一般把这个叫做 生成过程 (generative process)。
然后 第二步 ,就是通过贝叶斯推断(Bayesian inference)的方式来根据已知变量来倒推隐变量。
在我们目前的问题里,数据x以及标签y都是已知变量,而我们的encoding z以及domain index β和u则是 隐变量 。那么很自然,我们的目的就是已知各个domain里的 数据x以及标签y ,然后想推断出encoding z以及domain index β和u。注意,在domain adaptation里面,只有source domain才有已知的标签y。target domain只有数据x。
生成过程: 根据这个整体思路,我们就首先假设一下各个变量 生成过程 (如下图左边):
对于每个 domain :
从高斯分布 β α 中生成一个global domain index β ,
对于domain k中的每个数据点 ,
从高斯分布 β 中生成一个localdomain index ,
从高斯分布 中生成数据
从高斯分布 β 中生成encoding
从分 布 生成标签
用变分分布估计后验概率: 用变分分布估计后验概率:有了这个生成过程,我们就可以开始思考如何推断(infer) 出每个数据 对应的 encoding 及其domain index β 和 u 。我们首先会先构造一些变分分布(variational distribution),通过学些这些变分分布来推断 、β 和 u 。比如,如果我们会学会了变分分布 ,那么,给定一个数据 ,我们就能根据 得到local domain index 了。
在我们的方法里面我们一共定义了3个变分分布: , β ,和 β 。这里对应着上图的右边。在这几个分布里面,比较关键的是分布 β ,它会对同一个domain下所有数据的local domain index做一个聚合(aggregation),来推断这个domain的global domain index。注意每个数据都有自己的不同的local domain index,而同一个domain里的所有数据只共享同一个global domain index。这里的 {u}的大括号表示的是同一个domain里所有data对应的所有local domain index u组成的集合。在推断global domain index时,我们还在u的集合上应用了optimal transport,有兴趣的同学可以看下论文原文的细节。
Evidence Lower Bound(ELBO):接下来就是用ELBO把5个生成分布 β α , β , p(x_i∣u_i) , β , 和3个变分分布 , β , β 串成下面的目标函数:
从变分(variational inference)的角度,最大化上面的ELBO,等价于在寻找最优的变分分布 , β , β 来估计 , β , 的真实分布。
上面的目标函数可能有点冗长难懂,直接看 下图 可能会好些。直观地讲,我们可以把优化这个ELBO,看成 学习很多子网络 来对输入数据x进行 编码 (encode)和 重构 (reconstruct)的过程,关键在于,在这个编码和重构的过程中,需要 聪明地把domain index β 和 u 建模进去。
对贝叶斯推断(Bayesian Inference)熟悉的同学可能已经发现了,这个其实就是我们之前说的(广义的)贝叶斯深度学习(Bayesian Deep Learning)的思路:用 深度模块 (deep component)来处理高维信号x(比如图片),然后用 概率图模块 (graphical component)来表示各个随机变量之间的条件概率关系(比如图片x及其对应的encoding z和domain index β、u的关系)。
回到Domain Index的三段定义:讲到这里,眼尖的同学可能会发现,虽然最大化这个ELBO目标函数确实可能可以符合前面说的domain index的三个要求中的后两个,即保留x的信息(最大化互信息 β ) 和z对标签y的敏感度(最大化互 信息 ) , 但是却忽略了第一个要求,即z和β的条件独立(互信息 β )。
为了满足第一个要求,我们需要借鉴 对抗域迁移 (adversarial domain adaptation)的思想,在上图的基础上,再加上一个discriminator,然后对抗地(adversarially)训练整个网络,使得encoder能把不同domain的x映射到一个encoding空间,然后让这个discriminator无法从他们的encoding z来分辨出数据是来自于哪个domain的。我们把这个操作叫做encoding的对齐(alignment),即把不同的domain的encoding分布对齐起来,让他们互相重叠,这样就可以方便不同domain共享一个predictor了(比如分类器或者回归器)。加上discriminator之后的神经网络架构如下:。
最终的目标函数: 相应地,我们最终的目标函数也从一个简单的优化问题(最大化ELBO)变成了一个minimax game:
理论保障 : 有趣的是,我们可以严格地证明,上面的目标函数的全局最优点正好就可以同时满足我们对 domain index的三段定义 :即保留x的信息 (最大化互信息 β ) 、z对标签y的敏感度(最大化互信息 )、z和β的条件独立(互信息 β )。
学到了啥有意思的domain index: 既然有了理论保障,那么接下来我们可以看一下,如果按照上面的方法训练模型,我们能推断出来什么样的global domain index呢?我们用的第一个数据集是之前CIDA用的Circle数据集。这个数据集包含了30个domain,如下图所示。 左下图 是用颜色标记了domain index,我们可以看到颜色是渐变的,也就是说ground-truth的domain index是从1到30。绿色框里表示的是6个source domain,其他部分为target domain。 右下图 是用蓝色和红色标记了标签(label),可以看出来这是个二分类的数据集,蓝色表示正例,红色表示负例。
下面的图展示了我们的VDI学习到的domain index 和ground-truth domain index的对比。可以看到,我们学到的domain index和真正的domain index是 高度吻合 的,correlation达到了0.97。有趣的是,跟CIDA不一样,我们在训练VDI过程中,并没有用到任何的domain index,所有的domain index都是VDI模型自己以 无监督 的方式推断出来的。
除了Circle这个toy dataset,我们还测试了现实的数据集。比如之前我们在GRDA构建的TPT-48温度预测数据集。这个数据集有美国大陆48个州的每月气温。这里的任务(task)是,根据前6个月的气温,预测后6个月的气温(如下图左边)。我们把一部分州的数据作为source domain(如下图黑底白字的州),然后把其他州作为target domain(如下图白底黑字的州)。我们把target domain分成3个层级, level-1、level-2、和level-3的target domain分别表示 离source domain最近、次近、和最远的target domain。
有意思的是,即使在 无监督 (未知正确的domain index)的情况下,我们的VDI依然能够学出有意义的domain index。比如 下图左边 ,我们画出来VDI学出来的2维的domain index β。下面每个点的 坐标位置 表示的是我们VDI学到的2维domain index,而 颜色 则表示对应的domain(州)真实的纬度。我们可以看到,我们domain index的第一维(横轴)和真实的每个州的纬度高度吻合。比如纽约(NY)和新泽西(NJ)纬度距离比较近,而且都在比较北边(如下面的右图),那么对应的,他们的domain index也很接近。相反,佛罗里达(FL)离NY和NJ的纬度距离都比较远,对应地,它的domain index也离NY和NJ比较远。
另一个真实数据集是CompCar,CompCar里包含了各种车的照片,这些照片有2维真实的domain index, 拍照的角度 (比如正面照、侧面照、后面照等等)以及 出厂年份 (比如2009)。类似地,我们把VDI学到的2维domain index画到下图。下面每个点的坐标位置表示的是我们VDI学到的domain index,而 颜色 则表示真实的拍照角度(左图)和出厂年份(右图)。可以看到,即使是在 无监督 的情况下,我们学出来的domain index依然和真实的拍照角度和出厂年份高度相关。
提高domain adaptation准确率: 当然除了能学出有意思的domain index,VDI自然可以利用这些学到的domain index,来提高domain adaptation的准确度。下面的表格是TPT-48的温度预测误差(MSE)对比。我们可以看到VDI几乎在所有层级(level)的target domain都能有准确率的提高。
熟悉的同学可能可以看出来,这个VDI其实有点像是我们ICML’20的”Continuously Indexed Domain Adaptation”(CIDA)的逆问题,同时也可以看成是和CIDA这类算法的互补的问题。CIDA是想通过已知的domain index来提高连续域adaptation的准确度,而VDI则解决了一个更general的问题,也就是当这个domain index未知的时候,应该如何去推断出来。而且一旦推断出来domain index,我们就可以放心地继续使用CIDA来实现连续域(甚至是传统的离散域)的adaptation准确率的提升了。
Paper: https://arxiv.org/pdf/2302.02561.pdf or http://wanghao.in/paper/ICLR23_VDI.pdf OpenReview: https://openreview.net/forum?id=pxStyaf2oJ5 YouTube Video: https://www.youtube.com/watch?v=xARD4VG19ec Bilibili Video: https://www.bilibili.com/video/BV13N411w734/?share_source=copy_web GitHub Link: https://github.com/wang-ML-Lab/VDI
作者:王灏 https://www.zhihu.com/question/557295083/answer/2977965268
Illustration by IconScout Store from IconScout
-The End-
扫码观看!
本周上新!
“AI技术流”原创投稿计划
TechBeat是由将门创投建立的AI学习社区(www.techbeat.net)。社区上线480+期talk视频,2400+篇技术干货文章,方向覆盖CV/NLP/ML/Robotis等;每月定期举办顶会及其他线上交流活动,不定期举办技术人线下聚会交流活动。我们正在努力成为AI人才喜爱的高质量、知识型交流平台,希望为AI人才打造更专业的服务和体验,加速并陪伴其成长。
投稿内容
// 最新技术解读/系统性知识分享 //
// 前沿资讯解说/心得经历讲述 //
投稿须知
稿件需要为原创文章,并标明作者信息。
我们会选择部分在深度技术解析及科研心得方向,对用户启发更大的文章,做原创性内容奖励
投稿方式
发送邮件到
chenhongyuan@thejiangmen.com
或添加工作人员微信(chemn493)投稿,沟通投稿详情;还可以关注“将门创投”公众号,后台回复“投稿”二字,获得投稿说明。
关于我“门”
将门是一家以专注于发掘、加速及投资技术驱动型创业公司的新型创投机构,旗下涵盖将门创新服务、将门技术社群以及将门创投基金。
将门成立于2015年底,创始团队由微软创投在中国的创始团队原班人马构建而成,曾为微软优选和深度孵化了126家创新的技术型创业公司。
如果您是技术领域的初创企业,不仅想获得投资,还希望获得一系列持续性、有价值的投后服务,欢迎发送或者推荐项目给我“门”:
bp@thejiangmen.com
点击右上角,把文章分享到朋友圈
⤵一键送你进入TechBeat快乐星球
特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.