网易首页 > 网易号 > 正文 申请入驻

ICLR'23 Spotlight | 域适应中的域索引:定义、方法、理论和可解释性

0
分享至

来给大家介绍一下我们被接收为ICLR Spotlight的新工作。这个work从2021年春开始一直做到2022年秋,中间克服了许多技术障碍,没想到第一次投稿就好评如潮(分数8886),也恭喜子昊的坚持得到回报。这篇工作的核心贡献在于,正式定义了domain adaptation中的域索引(domain index),精心设计了推断(infer)domain index的算法(variational domain indexing, 即VDI),并且证明了我们的算法可以推断出最优的domain index。由于推断出来的domain index带来的free lunch,domain adaptation的性能也得到了提高。

论文链接: http://wanghao.in/paper/ICLR23_VDI.pdf 代码链接: https://github.com/Wang-ML-Lab/VDI

什么是domain index:domain index的说法最早在我们的ICML 2020论文“Continuously Indexed Domain Adaptation”(CIDA)中提出(有兴趣的看官欢迎移步我们讲CIDA的知乎帖子)。最直观的例子就是在医疗应用里面,不同年龄的人可以看成是不同的domain,而这个“年龄”其实就是domain的一个索引(index),也就是我们说的domain index(域索引)。如下图。

有意思的是,domain index其实是一个连续的概念,所以自然而然地包含了domain的远近信息。比如上面说的“年龄”可以作为一个一维的domain index,年龄18和19距离很近,而18和80却距离很远。我们之前在CIDA(大致的CIDA模型如下图)上的实验发现,如果已知这个domain index,我们可以很好地做到连续域上的domain adaptation,从而大幅提高准确率。

比如把模型从年龄0 ~ 20的病人(source domains),adapt到年龄20 ~ 80的病人(target domains),或者从年龄0 ~ 15以及50 ~ 80的病人(source domains),adapt到年龄15 ~ 50的病人(target domains),如下图。

那么问题来了,如果这个domain index是未知的,咋办? 最理想的情况当然是,我们能够把这个domain index作为隐变量(latent variable),通过 无监督 (unsupervised)的方式把它推断(infer)出来。如果这个方案可行,我们就免费拿到了一个重要的额外的信息,从而既可以提高domain adaptation的 准确率 ,又能提高它的 可解释性。

Domain Index的正式定义 : 在推断domain index前,我们要先定义清楚,什么才算是domain index,然后才能设计推断它的方法。

这里我们首先引入了两种domain index, local domain index(用u表示)和global domain index(用β表示)。 我们规定,虽然同一个domain里的不同数据点(data point)可以有不同的local domain index,但是同一个domain里的所有数据的global domain index必须是是 相同的 。也就是说,local domain index是一个instance-level的变量,而global domain index是一个domain-level的变量。下面的图是一个具体的例子,展示了global domain index β、local domain index u、数据x之间的关系。

那么符合什么条件的u和beta才能被叫做domain index呢?我们定义了三个条件(这里x表示数据,y表示标签,z表示x经过encoder后得到的encoding):

  1. z和β的条件独立:Encoding z和global domain index β是条件独立的。换句话说,他们的互信息 β 必须是0。

  2. 保留x的信息:Encoding z,global domain index β,和local domain index u这三组变量 比如尽可能地保留数据x的信息。换句话说,他们的互信息 β 必须达到最大

  3. z对标签y的敏感度:Encoding z要尽可能保留标签y的信息(这样才能提高预测y的准确率)。这意味着他们的互信息 必须达到最大 。

如果 β 和 u 满足上述三个条件,我们就把它们分别称为global domain index和local domain index。这三个条件可以用下面的数学公式表示:

方法的整体思路 : 定义完domain index后,下一个问题自然就是,如何能在无监督(完全不知道domain index)的情况下,有效地推断出符合上面三段定义的domain index β 和 u 呢?这时,就要请出adversarial Bayesian deep learning model(对Bayesian deep learning感兴趣的同学可以看看我们之前的帖子)来解决这个问题。

在Bayesian deep learning里面,或者更加传统的probabilistic graphical model里面,我们会分两步走:

  • 第一步 是首先假设一下 已知变量(observed variable) 是如何从 隐变量(latent variable,即未知的变量) 一步步生成的。我们一般把这个叫做 生成过程 (generative process)。

  • 然后 第二步 ,就是通过贝叶斯推断(Bayesian inference)的方式来根据已知变量来倒推隐变量。

在我们目前的问题里,数据x以及标签y都是已知变量,而我们的encoding z以及domain index β和u则是 隐变量 。那么很自然,我们的目的就是已知各个domain里的 数据x以及标签y ,然后想推断出encoding z以及domain index β和u。注意,在domain adaptation里面,只有source domain才有已知的标签y。target domain只有数据x。

生成过程: 根据这个整体思路,我们就首先假设一下各个变量 生成过程 (如下图左边):

对于每个 domain :

从高斯分布 β α 中生成一个global domain index β ,

对于domain k中的每个数据点 ,

从高斯分布 β 中生成一个localdomain index ,

从高斯分布 中生成数据

从高斯分布 β 中生成encoding

从分 布 生成标签

用变分分布估计后验概率: 用变分分布估计后验概率:有了这个生成过程,我们就可以开始思考如何推断(infer) 出每个数据 对应的 encoding 及其domain index β 和 u 。我们首先会先构造一些变分分布(variational distribution),通过学些这些变分分布来推断 、β 和 u 。比如,如果我们会学会了变分分布 ,那么,给定一个数据 ,我们就能根据 得到local domain index 了。

在我们的方法里面我们一共定义了3个变分分布: , β ,和 β 。这里对应着上图的右边。在这几个分布里面,比较关键的是分布 β ,它会对同一个domain下所有数据的local domain index做一个聚合(aggregation),来推断这个domain的global domain index。注意每个数据都有自己的不同的local domain index,而同一个domain里的所有数据只共享同一个global domain index。这里的 {u}的大括号表示的是同一个domain里所有data对应的所有local domain index u组成的集合。在推断global domain index时,我们还在u的集合上应用了optimal transport,有兴趣的同学可以看下论文原文的细节。

Evidence Lower Bound(ELBO):接下来就是用ELBO把5个生成分布 β α , β , p(x_i∣u_i) , β , 和3个变分分布 , β , β 串成下面的目标函数:

从变分(variational inference)的角度,最大化上面的ELBO,等价于在寻找最优的变分分布 , β , β 来估计 , β , 的实分布。

上面的目标函数可能有点冗长难懂,直接看 下图 可能会好些。直观地讲,我们可以把优化这个ELBO,看成 学习很多子网络 来对输入数据x进行 编码 (encode)和 重构 (reconstruct)的过程,关键在于,在这个编码和重构的过程中,需要 聪明地把domain index β 和 u 建模进去。

对贝叶斯推断(Bayesian Inference)熟悉的同学可能已经发现了,这个其实就是我们之前说的(广义的)贝叶斯深度学习(Bayesian Deep Learning)的思路:用 深度模块 (deep component)来处理高维信号x(比如图片),然后用 概率图模块 (graphical component)来表示各个随机变量之间的条件概率关系(比如图片x及其对应的encoding z和domain index β、u的关系)。

回到Domain Index的三段定义:讲到这里,眼尖的同学可能会发现,虽然最大化这个ELBO目标函数确实可能可以符合前面说的domain index的三个要求中的后两个,即保留x的信息(最大化互信息 β ) 和z对标签y的敏感度(最大化互 信息 ) , 但是却忽略了第一个要求,即z和β的条件独立(互信息 β )。

为了满足第一个要求,我们需要借鉴 对抗域迁移 (adversarial domain adaptation)的思想,在上图的基础上,再加上一个discriminator,然后对抗地(adversarially)训练整个网络,使得encoder能把不同domain的x映射到一个encoding空间,然后让这个discriminator无法从他们的encoding z来分辨出数据是来自于哪个domain的。我们把这个操作叫做encoding的对齐(alignment),即把不同的domain的encoding分布对齐起来,让他们互相重叠,这样就可以方便不同domain共享一个predictor了(比如分类器或者回归器)。加上discriminator之后的神经网络架构如下:。

最终的目标函数: 相应地,我们最终的目标函数也从一个简单的优化问题(最大化ELBO)变成了一个minimax game:

理论保障 : 有趣的是,我们可以严格地证明,上面的目标函数的全局最优点正好就可以同时满足我们对 domain index的三段定义 :即保留x的信息 (最大化互信息 β ) 、z对标签y的敏感度(最大化互信息 )、z和β的条件独立(互信息 β )。

学到了啥有意思的domain index: 既然有了理论保障,那么接下来我们可以看一下,如果按照上面的方法训练模型,我们能推断出来什么样的global domain index呢?我们用的第一个数据集是之前CIDA用的Circle数据集。这个数据集包含了30个domain,如下图所示。 左下图 是用颜色标记了domain index,我们可以看到颜色是渐变的,也就是说ground-truth的domain index是从1到30。绿色框里表示的是6个source domain,其他部分为target domain。 右下图 是用蓝色和红色标记了标签(label),可以看出来这是个二分类的数据集,蓝色表示正例,红色表示负例。

下面的图展示了我们的VDI学习到的domain index 和ground-truth domain index的对比。可以看到,我们学到的domain index和真正的domain index是 高度吻合 的,correlation达到了0.97。有趣的是,跟CIDA不一样,我们在训练VDI过程中,并没有用到任何的domain index,所有的domain index都是VDI模型自己以 无监督 的方式推断出来的。

除了Circle这个toy dataset,我们还测试了现实的数据集。比如之前我们在GRDA构建的TPT-48温度预测数据集。这个数据集有美国大陆48个州的每月气温。这里的任务(task)是,根据前6个月的气温,预测后6个月的气温(如下图左边)。我们把一部分州的数据作为source domain(如下图黑底白字的州),然后把其他州作为target domain(如下图白底黑字的州)。我们把target domain分成3个层级, level-1、level-2、和level-3的target domain分别表示 离source domain最近、次近、最远的target domain

有意思的是,即使在 无监督 (未知正确的domain index)的情况下,我们的VDI依然能够学出有意义的domain index。比如 下图左边 ,我们画出来VDI学出来的2维的domain index β。下面每个点的 坐标位置 表示的是我们VDI学到的2维domain index,而 颜色 则表示对应的domain(州)真实的纬度。我们可以看到,我们domain index的第一维(横轴)和真实的每个州的纬度高度吻合。比如纽约(NY)和新泽西(NJ)纬度距离比较近,而且都在比较北边(如下面的右图),那么对应的,他们的domain index也很接近。相反,佛罗里达(FL)离NY和NJ的纬度距离都比较远,对应地,它的domain index也离NY和NJ比较远。

另一个真实数据集是CompCar,CompCar里包含了各种车的照片,这些照片有2维真实的domain index, 拍照的角度 (比如正面照、侧面照、后面照等等)以及 出厂年份 (比如2009)。类似地,我们把VDI学到的2维domain index画到下图。下面每个点的坐标位置表示的是我们VDI学到的domain index,而 颜色 则表示真实的拍照角度(左图)和出厂年份(右图)。可以看到,即使是在 无监督 的情况下,我们学出来的domain index依然和真实的拍照角度和出厂年份高度相关。

提高domain adaptation准确率: 当然除了能学出有意思的domain index,VDI自然可以利用这些学到的domain index,来提高domain adaptation的准确度。下面的表格是TPT-48的温度预测误差(MSE)对比。我们可以看到VDI几乎在所有层级(level)的target domain都能有准确率的提高。

熟悉的同学可能可以看出来,这个VDI其实有点像是我们ICML’20的”Continuously Indexed Domain Adaptation”(CIDA)的逆问题,同时也可以看成是和CIDA这类算法的互补的问题。CIDA是想通过已知的domain index来提高连续域adaptation的准确度,而VDI则解决了一个更general的问题,也就是当这个domain index未知的时候,应该如何去推断出来。而且一旦推断出来domain index,我们就可以放心地继续使用CIDA来实现连续域(甚至是传统的离散域)的adaptation准确率的提升了。

Paper: https://arxiv.org/pdf/2302.02561.pdf or http://wanghao.in/paper/ICLR23_VDI.pdf OpenReview: https://openreview.net/forum?id=pxStyaf2oJ5 YouTube Video: https://www.youtube.com/watch?v=xARD4VG19ec Bilibili Video: https://www.bilibili.com/video/BV13N411w734/?share_source=copy_web GitHub Link: https://github.com/wang-ML-Lab/VDI

作者:王灏 https://www.zhihu.com/question/557295083/answer/2977965268 ‍

Illustration by IconScout Store from IconScout

-The End-

扫码观看!

本周上新!

“AI技术流”原创投稿计划

TechBeat是由将门创投建立的AI学习社区(www.techbeat.net)。社区上线480+期talk视频,2400+篇技术干货文章,方向覆盖CV/NLP/ML/Robotis等;每月定期举办顶会及其他线上交流活动,不定期举办技术人线下聚会交流活动。我们正在努力成为AI人才喜爱的高质量、知识型交流平台,希望为AI人才打造更专业的服务和体验,加速并陪伴其成长。

投稿内容

// 最新技术解读/系统性知识分享 //

// 前沿资讯解说/心得经历讲述 //

投稿须知

稿件需要为原创文章,并标明作者信息。

我们会选择部分在深度技术解析及科研心得方向,对用户启发更大的文章,做原创性内容奖励

投稿方式

发送邮件到

chenhongyuan@thejiangmen.com

或添加工作人员微信(chemn493)投稿,沟通投稿详情;还可以关注“将门创投”公众号,后台回复“投稿”二字,获得投稿说明。

关于我“门”

将门是一家以专注于发掘、加速及投资技术驱动型创业公司的新型创投机构,旗下涵盖将门创新服务将门技术社群以及将门创投基金

将门成立于2015年底,创始团队由微软创投在中国的创始团队原班人马构建而成,曾为微软优选和深度孵化了126家创新的技术型创业公司。

如果您是技术领域的初创企业,不仅想获得投资,还希望获得一系列持续性、有价值的投后服务,欢迎发送或者推荐项目给我“门”:

bp@thejiangmen.com

点击右上角,把文章分享到朋友圈

⤵一键送你进入TechBeat快乐星球

特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。

Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.

相关推荐
热点推荐
“将是场非常不幸的峰会”,外媒担忧:G7几乎全部人人“自身难保”

“将是场非常不幸的峰会”,外媒担忧:G7几乎全部人人“自身难保”

观察者网
2024-06-13 21:56:15
多看看美女有益身心健康

多看看美女有益身心健康

室内设计师阿喇
2024-06-13 15:37:16
最高赔偿超2w/㎡!成都这些人发财了!

最高赔偿超2w/㎡!成都这些人发财了!

楼市灭霸
2024-06-13 09:29:51
你无意中摸过什么不该摸的东西?网友:对视的那一刻,我俩都懵了

你无意中摸过什么不该摸的东西?网友:对视的那一刻,我俩都懵了

文雅笔墨
2024-06-13 20:51:06
17.59万起!深蓝G318上市即爆火:60分钟订单破3000台

17.59万起!深蓝G318上市即爆火:60分钟订单破3000台

快科技
2024-06-13 22:13:25
中国外贸发生大变化!高中低端市场全部失守,依赖我国消费者?

中国外贸发生大变化!高中低端市场全部失守,依赖我国消费者?

科普六点半
2024-06-04 09:27:58
厦门有大动作,台军一把手脸色都变了,放豪言:我要找大陆说清楚

厦门有大动作,台军一把手脸色都变了,放豪言:我要找大陆说清楚

千里持剑
2024-06-12 15:00:30
我突然理解,为啥上海女孩子真不会精致土

我突然理解,为啥上海女孩子真不会精致土

小虎新车推荐员
2024-06-13 17:30:33
基德:东契奇必须能防守 但也要明白我们会在他被攻击时保护他

基德:东契奇必须能防守 但也要明白我们会在他被攻击时保护他

直播吧
2024-06-13 16:33:13
阶级斗争熄灭论实现了,胡汉三回来了

阶级斗争熄灭论实现了,胡汉三回来了

雪中风车
2024-06-10 18:09:17
日本顶级财阀的联姻:丰田家族第五代继承人,迎娶女星“小赫本”

日本顶级财阀的联姻:丰田家族第五代继承人,迎娶女星“小赫本”

回京历史梦
2024-05-28 12:48:59
通报批评的好!吉林银行取2万需派出所同意后续,律师发声:违法

通报批评的好!吉林银行取2万需派出所同意后续,律师发声:违法

户外阿崭
2024-06-13 14:00:08
印度真的有网上说的那么脏乱差嘛?网友评论一针见血,给我笑麻了

印度真的有网上说的那么脏乱差嘛?网友评论一针见血,给我笑麻了

开玩笑的水母
2024-06-11 17:55:57
A股:又有消息来了,A股会继续下跌还是开启大涨?

A股:又有消息来了,A股会继续下跌还是开启大涨?

财经大拿
2024-06-14 03:30:02
庾澄庆这是怎么了?突然一下就老了,而且越来越像于谦

庾澄庆这是怎么了?突然一下就老了,而且越来越像于谦

盛华名阁汇
2024-06-12 14:00:40
IMF最新评估发现:人民币国际化出现停滞

IMF最新评估发现:人民币国际化出现停滞

长平投研
2024-06-13 22:06:40
孙楠冲上热搜!身高178,体重130斤引爆笑,坐韩红对面瘦到认不出

孙楠冲上热搜!身高178,体重130斤引爆笑,坐韩红对面瘦到认不出

山野下
2024-06-12 08:33:54
两性羞羞:添女友这里,她会嗨到不行

两性羞羞:添女友这里,她会嗨到不行

坟头长草
2024-05-30 16:23:58
《玫瑰的故事》直到黄亦玫遇到方协文的床,她都不知庄国栋错在哪

《玫瑰的故事》直到黄亦玫遇到方协文的床,她都不知庄国栋错在哪

小邵说剧
2024-06-13 21:00:42
60岁的夫妻,一周同房几次比较好?60岁绝经大姐一周一次,如何?

60岁的夫妻,一周同房几次比较好?60岁绝经大姐一周一次,如何?

39健康网
2024-05-19 23:20:03
2024-06-14 04:16:49
将门创投
将门创投
加速及投资技术驱动型初创企业
1821文章数 585关注度
往期回顾 全部

科技要闻

小红书员工仅1/5工龄满2年 32岁就不让进了

头条要闻

上海楼市新政后有业主熬夜卖房:比之前最低价高360万

头条要闻

上海楼市新政后有业主熬夜卖房:比之前最低价高360万

体育要闻

乔丹最想单挑的男人走了

娱乐要闻

森林北报案,称和汪峰的感情遭受压力

财经要闻

私募大佬孙强:中国为什么缺少耐心资本

汽车要闻

升级8155芯片 新款卡罗拉锐放售12.98-18.48万

态度原创

家居
数码
教育
时尚
旅游

家居要闻

大城小室 质朴自然的心灵居所

数码要闻

1999元起!飞米MINI 3无人机小米有品开售:4K录制、32分钟续航

教育要闻

TTS新传论文带读:粉过恋爱兄妹的李龙宇是我的赛博案底…

受法律保护的可颂,究竟有多好吃!?

旅游要闻

山西文旅厅厅长与董宇辉拉家常:中午回家吃了饭

无障碍浏览 进入关怀版