Variational Distillation of Diffusion Policies into Mixture of Experts
扩散策略的变分蒸馏至专家混合模型
https://openreview.net/pdf?id=iiYadgKHwo
摘要
本工作提出了变分扩散蒸馏(Variational Diffusion Distillation, VDD),这是一种通过变分推理将去噪扩散策略蒸馏到专家混合模型(Mixture of Experts, MoE)中的新方法。扩散模型因其在准确学习和表示复杂、多模态分布方面的卓越能力,已成为生成模型领域的最先进方法。这种能力使扩散模型能够复制人类行为中固有的多样性,因此在行为学习(如从人类示范中学习,LfD)中成为首选模型。然而,扩散模型也存在一些缺点,包括似然函数不可计算,以及由于迭代采样过程导致的较长推理时间。特别是推理时间,对机器人控制等实时应用构成了重大挑战。相比之下,MoE模型能有效解决上述问题,同时保留表示复杂分布的能力,但MoE模型众所周知难以训练。VDD是首个将预训练扩散模型蒸馏到MoE模型的方法,因此结合了扩散模型的表达能力与混合模型的优势。具体而言,VDD利用了变分目标函数的一个可分解上界,使得每个专家可以分别训练,从而形成一种鲁棒的MoE优化方案。VDD在九个复杂的行为学习任务中表现出:i)能够准确蒸馏扩散模型所学习的复杂分布;ii)优于现有的最先进蒸馏方法;iii)超越传统的MoE训练方法。代码和视频可在 https://intuitive-robots.github.io/vdd-website 获取。
1 引言
扩散模型[1–4]因其在诸如真实感图像生成[5–8]等多个领域取得的巨大成功而受到越来越多的关注。最近,扩散模型在从人类示范中学习(LfD)方面也展现出良好前景[9–13]。LfD的一个特别具有挑战性的方面是,由于人类行为本身具有高度多样性,导致数据分布具有高方差和多模态特性[14]。由于扩散模型具备泛化并表示复杂、多模态分布的能力,因此特别适合作为LfD中的策略表示方法。然而,扩散模型存在若干缺点,例如推理时间长以及似然计算不可行。为了获得高质量的样本,需要进行多次扩散步骤,从而导致较长的推理时间,限制了其在机器人控制等需要高频决策的实时应用中的使用。此外,扩散模型难以获得诸如精确似然等重要统计特性,这为通过策略梯度或最大熵强化学习等成熟强化学习(RL)方法进行事后优化(如微调)带来了显著挑战。
一种被广泛研究且能有效解决这些问题的方法是专家混合模型(Mixture of Experts, MoE)。在推理过程中,MoE首先选择一个专家,然后对该专家执行前向传播。这种分层结构提供了快速而简单的采样过程、可计算的似然以及表示多模态分布的能力。这些特性使MoE成为表示复杂、多模态行为的合适策略模型。然而,训练专家混合模型(MoE)通常困难且不稳定[15]。常用的极大似然目标函数可能导致不良行为,例如模式平均(mode-averaging),即模型无法准确表示某些模态。不过,这一局限性已被近期使用替代目标函数(如反向KL散度)的方法所缓解,这类方法不会出现模式平均问题[14, 16]。
为了同时获得两种模型的优势——即使用扩散模型学习高度精确的生成模型,同时利用专家混合模型获得简单且可计算的模型——本文提出了变分扩散蒸馏(VDD),这是一种将扩散模型蒸馏到MoE的新方法。从变分推理目标[17, 18]出发,我们推导出一个下界,将目标函数分解为各个专家独立的目标,从而形成一种鲁棒的优化方案。每个专家的目标函数巧妙地利用了预训练得分函数的梯度,使得MoE能够受益于扩散模型的特性。最终得到的MoE策略在性能上与扩散模型相当,覆盖相同的模态,同时具备可解释性、推理速度更快以及似然可计算的优点。该最终策略可直接供用户用于事后分析,或针对更具体场景进行快速微调。图1展示了VDD模型的高层架构及其与扩散策略的关系。VDD在九个复杂的行为学习任务上进行了全面评估,验证了上述特性。此外,本文还观察到,单步连续扩散模型已能表现良好,这一发现此前未在相关工作中讨论。
综上所述,本文提出了VDD——一种通过提出变分目标函数将扩散模型蒸馏到MoE的新方法,该目标函数可实现各专家的独立且鲁棒的更新,有效利用了预训练扩散模型。在九个复杂行为学习任务上的全面实验评估表明,VDD:i)能够准确蒸馏复杂分布;ii)优于现有的最先进蒸馏方法;iii)超越传统的MoE训练方法。
2 相关工作
用于行为学习的扩散模型。扩散模型已被用于获取复杂行为,以在各种学习框架中解决复杂任务。这些工作中的大多数使用离线强化学习[19–24]或模仿学习[9, 12, 11, 10, 25]来训练扩散策略。相比之下,VDD并非直接从数据中优化策略,而是将扩散模型蒸馏为一个专家混合(MoE)策略,以克服基于扩散的策略所存在的推理时间长或似然不可计算等缺点。
用于行为学习的专家混合模型(MoE)。MoE模型已被广泛研究,具有可计算的似然,并能表示多模态性,因此在许多领域中成为流行选择,例如在模仿学习[14, 26–31, 16, 13]、强化学习[32–38]和运动生成[39]中用于获取复杂行为。尽管VDD也使用MoE模型,但其行为是通过变分目标从预训练模型中蒸馏而来,而非从零开始训练。实证评估表明,VDD稳定的训练过程相较于常见的MoE学习技术能够带来性能提升。
从扩散模型进行的知识蒸馏。从扩散模型进行知识蒸馏已在多个研究领域中被探索。例如,在文本到3D建模中,提出了一种无需任何3D数据即可训练基于NeRF的文本到3D模型的方法:将3D场景映射到2D图像,并利用文本到2D的扩散模型[7]。该工作提出最小化一种称为得分蒸馏采样(Score Distillation Sampling, SDS)的损失,该损失受概率密度蒸馏[40]启发,激励3D模型向扩散模型得分函数指示的更高密度区域更新。为了克服使用SDS损失时出现的过度平滑和多样性不足等问题,变分得分蒸馏(Variational Score Distillation, VSD)将3D场景视为随机变量,并优化这些场景上的分布,使得投影得到的2D图像与2D扩散模型对齐[8]。在类似背景下,文献[41]提出将训练好的扩散模型蒸馏到另一个扩散模型中,同时逐步减少扩散步数。然而,尽管扩散步数大幅减少,但并未实现完全的蒸馏,即未实现像VDD那样的单步推理。此外,所得模型仍存在扩散模型的相同缺点,如似然不可计算。相比之下,在一致性蒸馏(Consistency Distillation, CD)中,扩散模型被蒸馏为一致性模型(Consistency Models, CM)[42–45],使得可以从噪声一步生成数据。然而,单步数据生成通常导致样本质量较低,因此需要在迭代生成和单步生成之间根据期望结果进行权衡。与CM类似,VDD也执行单步数据生成,但将预训练扩散模型蒸馏到MoE中,而MoE具有可计算的似然且推理时间高效。实验评估显示了VDD相对于CM的优势。Diff-Instruct[46]提出了一种两步框架,将扩散模型蒸馏到隐式生成模型中,而VDD考虑的是显式生成模型,其模型密度可直接评估。此外,Diff-Instruct需要训练一个辅助扩散模型,而VDD仅优化单一模型。得分正则化策略优化(Score Regularized Policy Optimization, SRPO)[47]也利用扩散行为策略来正则化基于离线强化学习的目标。然而,与SRPO不同,VDD学习的是MoE策略而非单模态高斯策略,并且显式地蒸馏扩散模型,而不是在优化过程中仅将其作为指导。此外,VDD在模仿学习中训练MoE策略,而不是像离线强化学习那样使用带有奖励标签的数据。一项并行工作EM蒸馏(EM-Distillation, EMD)[48]提出了一种基于模式覆盖型前向KL散度推导出的EM风格蒸馏目标。相比之下,VDD提出了一种基于模式寻求型反向KL散度的EM风格目标,但通过引入多个专家来鼓励模式覆盖行为。
3 预备知识
在这里,我们介绍去噪扩散和专家混合策略的符号表示和基础理论。在整个工作中,我们假设可以访问行为策略 π* 的样本以及相应的状态分布 μ,即 a ∼ π*(·|s) 和 s ∼ μ(·),分别对应。
去噪扩散策略。去噪扩散策略采用扩散过程将数据平滑地转化为噪声。对于给定的状态 s',扩散过程被建模为随机微分方程(SDE)[3]。
4 去噪扩散策略的变分蒸馏
4.1 可扩展的变分推理用于去噪扩散策略蒸馏
4.2 通过专家混合进行变分推理
为了蒸馏扩散模型所学习的多模态分布,我们需要比通常在近似变分推理(amortized VI)中使用的条件对角高斯分布更复杂的分布族。因此,我们将使用高斯混合专家(Gaussian mixture of experts)。为此,我们构建一个可分解为每个专家单独目标的 J 的上界,从而允许对每个专家进行单独的重新参数化,避免了需要对整个 MoE 进行重新参数化的技术 [60, 61]。通过利用 KL 散度的链式法则 [62, 63, 15],可以得到上界 U(φ, q̃),即,
4.3 选择扩散时间步
5 实验
我们通过蒸馏两种类型的扩散模型进行了模仿学习实验:方差保持型(Variance Preserving, VP)[2, 12] 和方差爆炸型(Variance Exploding, VE)[65, 4]。我们选择DDPM作为VP的代表,BESO作为VE的代表。我们采用了文献[9]和[13]中关于采样器和去噪步数的选择。关于不同去噪步数的教师模型的额外评估见附录F。在实验中,VP-1和VE-1表示在推理过程中仅执行各自扩散模型一次去噪步骤的结果。VDD-VP和VDD-VE表示蒸馏得到的VDD方法的结果。
此外,我们将当前最先进的方法一致性蒸馏(Consistency Distillation, CD)[42] 和一致性轨迹模型(Consistency Trajectory Model, CTM)[44] 作为基线,用于比较VDD在蒸馏任务中的性能。对于CD和CTM,我们按照原始论文的做法,从VE模型进行蒸馏。对于CTM,我们借鉴了一致性策略(Consistency Policy)[45] 中的实现和设计选择,这些选择专门针对行为学习进行了优化。
我们还将VDD与MoE学习的基线方法进行比较,包括广泛使用的期望最大化(Expectation-Maximization, EM)[66] 方法(作为基于最大似然目标的代表),以及最近提出的最先进方法信息最大化课程(Information Maximizing Curriculum, IMC),作为基于反向KL散度目标的代表。为了使这些基线更具竞争力,我们采用图5中描述的架构对它们进行了扩展,并将扩展后的方法分别命名为EM-GPT和IMC-GPT。
为了公平比较,我们对所有蒸馏方法均使用相同的扩散模型作为源模型,并在随机种子0上进行训练。为了进行具有统计显著性的比较,所有方法均在4个不同的随机种子上运行,评估结果中报告了均值和标准差。有关基线方法的实现细节和超参数选择的详细说明见附录D和E。
评估结构如下:首先,我们在两个已建立的数据集上证明,VDD能够与当前最先进的扩散蒸馏方法以及原始扩散模型取得具有竞争力的性能。接着,我们在一个最近提出的、包含人类示范的具有挑战性的基准任务上进行实验,VDD在此任务上优于现有的扩散蒸馏方法以及最先进的MoE学习方法。随后,我们强调了VDD更快的推理速度。接下来,通过一系列消融研究,验证了VDD关键算法特性的重要性。最后,我们提供了可视化结果,以更深入地理解我们的方法。
5.1 在模仿学习数据集中的竞争性蒸馏性能
我们首先在两个广泛认可的模仿学习数据集上展示VDD的有效性:Relay Kitchen [67] 和 XArm Block Push [68]。这些环境的详细描述见附录C。为了确保公平比较,我们遵循文献[9]中所述的相同评估流程。表1a展示了Relay Kitchen环境的奖励得分和XArm Block Push任务的成功率,结果基于100次环境运行的均值和标准差。结果表明,VDD在这两个任务中的表现与一致性蒸馏(CD)相当,在积木推动数据集上略优。一个额外的有趣发现是,仅使用一步去噪的BESO模型(VE-1)在这两个任务中已表现出较强的基线性能,因为在两种情况下,原始模型的表现均优于所有蒸馏方法的结果。我们认为这一有趣现象的原因在于,Relay Kitchen和XArm Block Push任务相对容易解决,且未提供多样化、多模态的数据分布。因此,我们进一步在最近发布的一个数据集D3IL [13] 上评估这些方法,该数据集专门用于复杂的机器人模仿学习任务,并提供了任务熵(task entropy)的测量指标。
5.2 在保持多样化行为的同时复现扩散模型性能
D3IL基准测试为多个具有挑战性的机器人操作任务提供了人类示范,重点评估方法在任务成功率和行为多样性(即解决同一任务的不同行为方式)两方面的能力。该基准为每个任务提供了一个多样性度量指标,称为任务熵(task entropy)。任务熵是一个介于0到1之间的标量值:0表示模型仅学会了一种解决任务的方式,1表示模型已覆盖了人类示范中的所有技能。环境的详细描述以及任务熵的计算方法见附录C。蒸馏后策略的任务成功率如表1a所示。结果显示,在7个任务中的6个任务上(除“对齐”任务外),VDD优于一致性蒸馏方法以及原始模型的单步变体。然而,在“对齐”任务中,VDD取得了更高的任务熵,表明其学习到了更多样化的行为。任务熵的结果如表1b所示。结果表明,在7个任务中的4个任务上,VDD的任务熵高于一致性模型(CD、CTM)和单步扩散模型(VP-1、VE-1),这说明我们的方法能够从扩散策略中成功复现高质量且多样化的行为。
之前的评估表明,VDD能够有效地将扩散模型蒸馏到专家混合模型(MoE)中,同时保持原有的性能和行为多样性。在本节中,我们通过将VDD与EM-GPT和IMC-GPT进行比较,讨论使用VDD的必要性,而非直接从零开始训练MoE模型。这两种方法都是从头开始训练MoE模型,但在目标函数上有所不同:EM基于广为人知的最大似然目标,而IMC基于反向KL散度目标。表2中的结果显示,在大多数任务中,无论是在任务成功率还是任务熵方面,VDD均持续优于EM-GPT和IMC-GPT。我们认为性能提升的原因在于:利用了扩散模型的泛化能力,以及我们所提出的可分解下界(公式(11))所带来的稳定更新。
5.4 蒸馏后MoE的快速推理
MoE模型的推理不需要迭代去噪过程,因此采样速度更快。我们在基于状态的推动任务和基于图像的堆叠任务上,将VDD与DDPM和BESO的推理时间进行比较,并在表3中报告了200次预测的平均结果。除了以毫秒为单位的绝对推理时间外,表3还报告了函数评估次数(NFE),以便更好地进行比较。结果表明,在两种情况下,VDD均显著快于原始扩散模型,即使扩散模型仅执行一次去噪步骤也是如此。为了公平比较,所有方法使用了相同数量的Transformer层。预测均在同一系统上进行(RTX 3070 GPU,Intel i7-12700 CPU)。
5.5 消融研究
我们通过报告在四个不同随机种子下任务性能和任务熵的平均值,评估VDD关键特性在不同环境中的重要性。
专家数量影响任务熵。我们首先在“躲避”任务上固定其他所有超参数,仅改变MoE模型的专家数量。图3a展示了使用VDD训练的MoE模型的平均任务成功率和任务熵。除单个专家情况(即高斯策略)外,所有专家数量下的成功率几乎保持高位,且单专家情况下的成功率略高。然而,单个专家只能覆盖多模态行为空间中的单一模式,因此任务熵为0。随着专家数量增加,任务熵随之上升,并在轻微下降后趋于收敛。
训练门控分布可提升性能。图3b展示了在训练参数化门控网络 qξ(z|s)(红色)与将选择专家z的概率固定为 q(z) = 1/N(蓝色,其中N为专家数量)时的成功率。尽管训练门控分布在三个不同任务中均提升了成功率,但在三个任务中的两个任务上,任务熵略有下降(见图3c)。这一观察是合理的:带有训练门控的MoE会导致每个专家根据输入进行特化,而固定门控的专家则被迫在所有可能的输入下解决任务。
时间步区间采样可提高任务熵。此处我们探索了公式(16)中引入的不同时间步分布 p(t)。图3d中考虑了多种方法:使用最小时间步,即 p(t) = limₜ→₀ δ(t),其中δ表示狄拉克δ分布;最大时间步 p(t) = δ(T);在[0, T]上的均匀分布;以及在子区间[t₀, t₁] ⊂ [0, T]上的均匀分布,其中区间边界t₀、t₁为超参数。尽管各变体的成功率相近,但区间采样取得了最高的任务熵且成功率非常高。因此,我们将区间时间步采样作为默认设置。这些结果来自“躲避”任务。
5.6 每个专家行为的可视化
我们提供了关于 D3IL 任务套件中“避免”任务的额外可视化,旨在进一步理解 VDD 如何利用各个专家。图 4 展示了根据给定状态下门控分布的似然进行的专家选择,提供了几个关键见解。首先,VDD 有效地提取了具有不同行为的专家,例如,z1 通常向下移动,z2 倾向于向上移动,而 z3 和 z4 倾向于产生水平移动。其次,门控机制有效地停用了大多数状态下的冗余专家,这表明可以使用更多的组件而不会影响性能,因为门控机制停用了冗余专家。最后,使用单一组件(Z = 1)可以在失去行为多样性的代价下实现完美的成功率。相反,使用许多专家可能会略微降低成功率,但增加了行为多样性。这些定性结果与图 3a 中消融研究的定量结果一致。
6 结论
本文提出了变分扩散蒸馏(Variational Diffusion Distillation, VDD),一种新颖的方法,可将扩散模型蒸馏为专家混合模型(MoE)。VDD使MoE能够继承扩散模型的优点,如良好的泛化能力以及对复杂、多模态数据的表征能力,同时避免了扩散模型的缺点,例如推理时间长和似然计算不可行。基于变分目标,VDD推导出一个下界,使得每个专家可以独立优化。该下界带来了稳定的优化过程,并巧妙地利用了预训练得分函数的梯度,从而使整体MoE模型有效继承了扩散模型的特性。在九个复杂的技能学习任务上的评估表明,VDD相较于当前最先进的方法,在保持学习多样化技能能力的同时,实现了相当甚至更优的蒸馏性能。对专家数量的消融实验表明,单个专家已能表现良好,但无法以多样化的方式解决任务。此外,结果还显示,训练门控分布能显著提升VDD的性能,但会降低任务的熵(即行为多样性)。
局限性。由于MoE需要预测上下文相关的均值和协方差,VDD难以直接应用于生成如图像等非常高维的数据。将VDD扩展到图像领域需要进一步改进,例如在潜在空间中进行预测。此外,专家数量必须由用户预先设定。然而,过多的专家可能增加VDD的训练时间,并可能降低后续使用强化学习进行微调时的专家利用率。与其他蒸馏方法类似,VDD的性能受限于原始模型的表现。
未来工作。一个有前景的研究方向是利用扩散“教师”模型的特征来缩短训练时间并提升性能。这可以通过将扩散模型作为骨干网络,并在其上微调一个MoE头部来实现,用于预测各专家的均值和协方差矩阵。扩散模型的时间依赖性可被直接利用,以在多个噪声水平上训练MoE,从而有效消除第4.3节中引入的时间步选择机制。
更广泛的影响。改进和增强模仿学习算法可能使机器人等现实应用变得更加普及,这既带来积极影响,也可能引发负面后果。我们承认,识别这些潜在负面影响的责任在于主权政府。
原文链接:https://openreview.net/pdf?id=iiYadgKHwo
特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.