Structure learning with Temporal Gaussian Mixture for model-based Reinforcement Learning
基于模型的强化学习中时间高斯混合结构学习
https://arxiv.org/pdf/2411.11511
摘要
基于模型的强化学习是一类能够通过构建环境显式模型实现样本高效决策的方法,该模型可用于学习最优策略。本文提出了一种时域高斯混合模型,由感知模型和转移模型组成。感知模型利用变分高斯混合似然从连续观测中提取离散(潜在)状态。值得注意的是,我们的模型会持续监控收集的数据以搜索新的高斯分量,即感知模型在学习混合模型中高斯分量数量的过程中执行了一种结构学习(Smith 等人,2020;Friston 等人,2018;Neacsu 等人,2022)。此外,转移模型利用狄利克雷 - 分类共轭性学习连续时间步之间的时间转移关系。感知模型和转移模型均能够遗忘部分数据点,同时将数据提供的信息整合到先验中,从而保证快速变分推理。最后,决策过程采用一种能够从状态信念中学习 Q 值的 Q 学习变体实现。实验表明,该模型能够学习多种迷宫的结构:模型可发现状态数量及状态间的转移概率。此外,利用学习到的 Q 值,智能体能够成功从起始位置导航至迷宫出口。
关键词:结构学习、Q 学习、强化学习、贝叶斯建模、高斯混合
1. 引言
基于模型的强化学习是一种描述智能体应如何与环境交互的理论。更准确地说,智能体维持一个由观测、状态和动作组成的环境模型。当新观测可用时,智能体需要估计环境最可能的状态。这一过程通常被称为感知或推理,可通过最小化变分自由能来实现,在机器学习中这也被称为负证据下界(Blei 等人,2017)。
变分推理是一种近似推理形式,其中真实后验分布由变分分布近似。具体而言,推理通常通过将近似后验限制在一类分布中来变得易于处理,其中每个分布对应不同的参数值集合。推理即指优化这些参数,以使变分自由能最小化。
变分推理中常用的一个重要分布族是指数族(Holland 和 Leinhardt,1981),它包含了大多数知名的分布。例如,高斯分布、伯努利分布和狄利克雷分布都是该族的成员。重要的是,指数族内的所有分布都可以映射到相同的函数形式,这可用于推导模块化推理算法,如变分消息传递(Winn 和 Bishop,2005)。
在先前的这些方法中,生成模型的结构是预先已知的,即智能体已知潜在状态的数量。实际上,这意味着模型需要由领域专家设计,而非由智能体学习得到。然而,对于某些应用,模型可能不可用,或者专家可能仅对问题有部分了解。在这种情况下,学习模型变得至关重要。近期的工作主要关注观测为离散的问题(Smith 等人,2020;Friston 等人,2018;Neacsu 等人,2022)。相比之下,本文解决的是观测为连续的情况。
在本文中,我们假设数据服从多元高斯分布的混合分布。然后,混合的每个分量与一个潜在状态相关联,我们的目标是学习这些状态如何随时间变化。关键挑战在于确定正确的分量数量。这要求模型在必要时添加和删除分量。虽然高斯混合模型(GMM)能够修剪多余的分量,但我们还需要一种方法,在发现新的数据点簇时增加分量数量。为此,我们的智能体持续监控数据以寻找新的数据点簇,当识别到此类簇时,会向混合模型中添加新的高斯分量。
尽管使用变分推理学习模型结构和推断当前状态很有用,但它并未规定应采取哪些动作。因此,下一步包括比较不同策略的质量,在强化学习中,策略的质量是奖励的折扣和。理想情况下,我们希望比较所有可能的策略,并选择质量最高的策略。不幸的是,有效策略的数量会随着规划的时间范围呈指数增长。具体来说,如果智能体在T个时间步中每个时间步可以选择A个动作中的一个,那么可能的策略数量为:。
这种指数级增长使得穷举搜索难以处理。相反,可以使用蒙特卡洛树搜索(Browne 等人,2012)等算法来高效探索有效策略的空间(Champion 等人,2022a,b,c,d)。另一种方法是 Q 学习算法(Mnih 等人,2015),其中智能体不是使用搜索树来搜索策略空间,而是学习在每个状态下执行每个动作的价值。然而,Q 学习要求状态是可观测的。因此,我们提出了一种 Q 学习算法的变体,它可以从状态的信念中学习。
在第 2 节中,我们介绍了变分高斯混合模型,该模型用于学习分量数量并对潜在变量进行推理。第 3 节解释了如何利用分类 - 狄利克雷模型来学习时间转移。然后,在第 4 节中,我们描述了一种优化方法,允许智能体通过将数据点提供的信息整合到先验中,从而遗忘部分数据点。通过遗忘部分数据,智能体减少了所需的内存量,同时加快了推理过程。接下来,在第 5 节中,我们描述了如何调整 Q 学习以处理状态的信念。最后,在第 6 节中,我们对我们的方法进行了实证验证,然后在第 7 节中总结和结束本文。
2. 感知模型
在本节中,我们将讨论感知模型的理论和实现细节。具体而言,我们将描述如何使用均值漂移聚类(mean shift clustering)初始化变分高斯混合(VGM, Variational Gaussian Mixture)模型的参数,然后介绍 VGM 生成模型、变分分布及更新方程。
2.1 均值漂移聚类
2.2 变分高斯混合模型
尽管均值漂移算法能够识别数据点的簇结构,但该算法的可扩展性随数据点数量增加而降低 —— 即对于每个新增数据点,方程(1)都需迭代至收敛。因此,我们需要找到一种在不丢失数据所含信息的前提下丢弃数据点的方法。如后文所示,这可通过变分高斯混合(VGM)模型实现(Bishop and Nasrabadi, 2006)。在以下子节中,我们将讨论变分高斯混合模型的生成模型、变分分布、变分自由能及更新方程。
2.2.1 生成模型
2.2.2 经验先验与变分分布
2.2.3 变分自由能
变分推理的目标是使变分分布尽可能接近真实后验。从数学上讲,我们希望:
2.2.3 变分自由能
变分推理的目标是使变分分布尽可能接近真实后验。从数学上讲,我们希望:
2.2.4 更新方程
3. 转移模型
在本节中,我们将解释如何使用分类 - 狄利克雷模型学习连续时间步之间的转移映射。当与上一节的感知模型结合时,所得模型为时域高斯混合模型(TGM,Temporal Gaussian Mixture),能够学习环境结构,即给定动作时可学习状态数量及状态间转移概率的模型。
3.1 时域高斯混合模型
尽管变分高斯混合模型(见 2.2 节)能够学习数据的静态模型,但无法对时间序列建模。一种潜在解决方案是在每个时间步复制高斯混合似然,并定义任意两个连续时间步之间的转移概率(见图 5 (a))。但如实验所示,如图 6 (a),添加此类转移映射会干扰 VGM 为每个数据点簇学习单一分量的能力。简而言之,我们通过实验发现,先学习似然,再利用潜在变量的后验学习转移映射,比同时学习似然和转移映射的效果更好。
因此,我们为转移映射使用第二个生成模型,模型结构如图 5 (b) 所示。需注意,Z0 和 Z1 的变分分布参数通过迭代 2.2.4 节的更新方程计算,且在转移模型的推理过程中保持固定。
更具体地说,当智能体在环境中执行动作时,会记录观测 X 和执行的动作 A。对于每个观测值 xn∈X,VGM 用于计算对应潜在变量 zn 的最优变分分布。但转移模型依赖于两个观测变量 X0 和 X1(即不依赖于 X)。具体方法是将任意两个连续观测 xn 和 xn+1∈X 分别视为 X0 和 X1 的观测值。因此,zn 和 zn+1∈Z 的变分分布参数可分别用于定义 Z0 和 Z1 的变分后验。类似地,每当 xn 和 xn+1 作为 X0 和 X1 的观测值时,动作 an 成为 A0 的观测值。
3.1.1 生成模型
我们现在聚焦于转移模型的定义。与 2.2.1 节类似,D 的先验是狄利克雷分布,Z0 的先验是由 D 参数化的分类分布:
3.1.2 经验先验、变分分布与更新方程
4. 要遗忘还是不要遗忘
在前面的章节中,我们描述了感知模型和转换模型。重要的是,我们解释了每个模型都提供了两组数据点,即可以被遗忘的数据点和需要保留在记忆中的数据点。可以被遗忘的数据点用于计算经验先验,即仅考虑要遗忘的数据点的后验分布。重要的是,经验先验与先验分布具有相同的形式,因此在推理过程结束后,可以将其用作先验信念,从而有效地将可遗忘数据点提供的信息整合到先验中。我们现在专注于如何决定哪些数据点要遗忘,哪些要保留。
4.1 塑性与稳定性困境
在学习过程中,大脑需要具有可塑性,允许突触连接的变化以及现有突触连接强度的变化。然而,这些变化最好能够保留之前学到的概念,以确保某种形式的稳定性,并避免遗忘有用的信息。这就是可塑性与稳定性的困境(Mermillod 等人,2013)。
我们的模型面临着类似的挑战,因为一般来说,随着代理探索环境,新的组件会逐渐被发现。这些新组件需要在不遗忘旧组件的情况下进行学习。例如,想象一只老鼠在迷宫中移动并观察其位置的噪声估计。随着新的迷宫单元被探索,新的位置被采样,新的组件将变得可见。然而,如果新组件的样本数量太少,这个组件可能会与附近的单元混淆,直到有足够的数据点可以区分这些单元。
4.2 灵活组件与固定组件
在本节中,我们讨论我们的模型如何解决可塑性与稳定性的困境。更具体地说,我们假设高斯组件可以是灵活的(可塑的)或固定的(稳定的)。所有组件最初都是灵活的,如果它们在较长时间内持续存在,它们就会变得固定。图 8 展示了组件如何转变为固定组件。
4.3 哪些数据点应该被遗忘?
5. 规划与决策
在本节中,我们将解释如何调整 Q 学习算法(Sutton 和 Barto,2018),使其适用于当前状态存在不确定性的问题。首先,我们介绍标准 Q 学习算法,然后对该算法进行改进,使其能够处理状态上的信念(belief)。
5.1 Q 学习算法
5.2 基于信念的 Q 学习
我们该如何调整 Q 学习以使其能够适用于随机状态呢?我们建议按照状态的后验概率对公式 (69) 中的时间差进行缩放,更具体地说:
6. 实验
在本节中,我们在一个迷宫解决任务上验证了我们的方法。代理能够向上、向下、向左和向右移动,但如果所选动作会导致代理撞墙,则代理不会移动。如图 12 所示,代理需要从初始位置(用老鼠表示)导航到目标状态(用奶酪表示)。然后,当目标状态被到达后,代理必须执行“吃”动作。我们根据两个标准研究了我们代理的性能:(i) 其学习各种迷宫结构的能力,以及 (ii) 其通过到达目标状态并吃掉奶酪来解决迷宫的能力。
首先,我们关注模型学习迷宫结构的能力。通过手动检查学习到的组件和转换矩阵,我们识别出了哪些单元格被正确学习了(见图 12 中的绿色和橙色单元格),以及哪些单元格没有被学习(见图 12 中的红色单元格)。重要的是,随着状态数量的增加,一些组件变得不稳定,有效地接管了对应于邻近单元格的数据点,受影响的单元格用橙色标出。
其次,我们研究了我们的方法可以解决哪些迷宫。我们发现,图 12(a)、12(c)、12(d) 和 12(e) 中的迷宫可以通过我们的方法解决。然而,由于缺乏探索,导致无法学习整个迷宫结构,代理无法解决图 12(b) 中的迷宫。同样,图 12(f) 中的迷宫也未能解决,因为大量的状态使得高斯组件变得不稳定,随着时间的推移,一些组件未能被学习。
另一个有趣的问题是 TGM 代理学习解决每个迷宫的速度有多快。为了回答这个问题,我们依赖于 TGM 与两种强化学习算法之间的经验比较,即深度 Q 网络(DQN)和优势演员评论家(A2C)。简单来说,DQN 代理(Mnih 等人,2015)学习在每个状态下采取每个动作的价值,而 A2C(Mnih 等人,2016)是一种基于策略的方法,配备了评论家网络。请注意,DQN、A2C 和 PPO 都是无模型的方法,它们学习从状态到动作(或价值)的映射,而不会创建环境的模型。
图 13 展示了 TGM、DQN 和 A2C 代理在图 12 对应迷宫中收集的平均单集奖励。图 13 显示,在图 12(a) 的迷宫中,TGM 的表现优于 DQN 和 A2C。有趣的是,所有三种代理(即 TGM、DQN 和 A2C)都未能解决图 12(b) 的迷宫。这表明,尽管拓扑结构相当简单(即一个长走廊),但这种环境具有挑战性,需要一种特殊的探索策略。换句话说,随机选择动作极不可能使代理到达目标状态,因此代理应该积极寻求探索环境中较少访问的部分。
此外,在图 12(c) 的迷宫中,DQN 似乎比 TGM 学习得更快,然而,两种代理最终都达到了类似的渐近平均奖励。对于图 12(d) 的迷宫,情况则相反,TGM 比 DQN 学习得更快,但两种代理最终也达到了类似的渐近平均奖励。请注意,在迷宫 12(c) 和 12(d) 中,A2C 的表现不如 DQN 和 TGM。
最后,在迷宫 12(e) 和 12(f) 中,DQN 的表现优于 TGM 和 A2C。重要的是,在图 13(f) 中,TGM 的方差较小,因为 TGM 代理总是无法解决迷宫 12(f)。相比之下,图 13(e) 中较大的方差表明 TGM 代理在成功解决迷宫和失败之间交替出现。
总之,我们证明了 TGM 代理在与强化学习基准(如 DQN 和 A2C)的竞争中表现良好。更具体地说,TGM 在所有测试的迷宫中都优于 A2C,除了一个所有代理都未能完成任务的迷宫。这可能是因为 TGM 比 A2C 更稳定,如图 13(c) 至 13(e) 所示。同样,TGM 在一个迷宫中的表现优于 DQN,在三个迷宫中的表现与 DQN 相当,在两个迷宫中的表现不如 DQN。然而,尽管 TGM 有时表现不如 DQN,但它学习了一个可解释的环境模型,这在某些应用中可能是可取的。此外,这个环境模型由隐藏状态组成,这些隐藏状态可能有助于开发更高级的探索策略,这些策略旨在探索较少访问的状态。
7 结论
在本文中,我们解决了在观测值为连续变量时从数据中学习模型结构的问题。更具体地说,我们旨在识别数据中存在的簇的数量,每个簇与一个潜在状态相关联。随着智能体收集到更多数据,新的簇会被发现,模型需要增加其状态数量。相反,如果智能体当前认为的状态数量多于数据中实际的簇数量,则需要减少状态数量。
我们提出了一种变分高斯混合模型,其中高斯混合模型能够移除不受数据支持的簇,并且智能体持续监控数据以搜索新的簇。当识别到新簇时,会将相应的组件添加到混合模型中。这两种机制使模型能够学习组件的数量。与此同时,智能体还通过变分推理学习各个组件的参数。
然而,推断当前状态只是其中的一部分,智能体还需要学习不同动作如何影响状态之间的时间转移。这是通过利用分类 - 狄利克雷模型来实现的,该模型有效地记录了智能体在执行特定动作时从一个状态转移到另一个状态的次数。
一旦智能体能够进行推理并预测其动作的后果,最后一个要求就是进行决策和规划。我们提出了一种 Q 学习算法的变体,该变体适应随机状态,即对状态有信念而非可观测状态。这种新算法使智能体能够通过到达出口位置并 “吃奶酪” 来解决多个迷宫问题。
在实验上,我们的方法能够解决多个迷宫问题,但仍然存在局限性。例如,该方法容易遗忘现有组件,其中多个组件会融合成一个组件。此外,由于使用 ε- 贪婪策略来权衡探索和利用,智能体往往难以探索迷宫的偏远部分。这些局限性表明,需要更多的研究来提高高斯组件的稳定性。此外,智能体将受益于改进探索策略,例如,智能体可能应该专注于发现迷宫中很少被探索的部分,而不是随机选择动作。而且,所建模的环境在某种程度上是理想化的,例如,生物体遵循相当刻板的轨迹,从一个簇传递到另一个簇。
此外,人们可以研究学习组件数量的替代方法,如无限高斯混合模型(Rasmussen,1999)。另一个有趣的研究方向是应用时间高斯混合来研究深度主动推理中学习到的潜在表示(Fountas 等人,2020;C¸atal 等人,2020;Champion 等人,2023;Millidge,2020;Sancaktar 等人,2020;Lanillos 等人,2020;Oliver 等人,2019;van der Himst 和 Lanillos,2020),或者类似地使用变分高斯混合来研究变分自编码器学习到的表示(Doersch,2016;Higgins 等人,2017;Kingma 和 Welling,2014;Rezende 等人,2014;Bai 等人,2022;Dilokthanakul 等人,2016)。
附录 A:符号与重要性质
在本文中,我们广泛使用了若干定义和性质,本附录对其进行了总结。附录 A.1 提供了各种概率分布的定义,附录 A.2 介绍了概率论相关性质,附录 A.3 则强调了线性代数中的性质。
A.1:概率分布
均值为μ且精度矩阵为Λ的多元高斯分布的概率密度函数(PDF)表示为:
https://arxiv.org/pdf/2411.11511
特别声明:以上内容(如有图片或视频亦包括在内)为自媒体平台“网易号”用户上传并发布,本平台仅提供信息存储服务。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.