当前位置: 首页 > 科技观察

你不能两者兼得吗?清华团队提出高精度可解释分类模型

时间:2023-03-15 18:53:01 科技观察

现有的机器学习分类模型从性能和可解释性两个维度大致可以分为两类:深度学习和集成学习(如随机森林,XGBoost)作为代表性的分类模型分类性能好,但模型复杂度高,可解释性差,而以决策树和逻辑回归为代表的模型可解释性强,但分类性能不理想。清华大学(第一作者王卓,王建勇教授的博士研究生)、华东师范大学(张伟,2016年毕业于清华大学)和山东大学(刘宁,2021年毕业于清华大学)。基于规则表示学习的分类模型RRL。RRL既具有像决策树模型这样的高可解释性,又具有像随机森林和XGBoost这样的集成学习器的分类性能。相关论文已入选NeurIPS2021。论文链接:https://arxiv.org/abs/2109.15103代码链接:https://github.com/12wang3/rrl为了同时获得良好的可解释性和分类性能,论文提出了一种新的分类模型————规则表示学习器(RRL)。RRL能够通过自动学习可解释的非模糊规则来进行数据表示和分类。为了高效地训练不可微的RRL模型,本文提出了一种新的训练方法——梯度嫁接法。通过梯度嫁接,可以使用梯度下降直接优化离散RRL。此外,论文还设计了改进版的逻辑激活函数,不仅提高了RRL的可扩展性,还使其能够端到端地离散化连续特征。在9个小规模和4个大规模数据集上的实验表明,RRL的分类性能明显优于其他可解释方法(如第二届“AI诺贝尔奖”获得者CynthiaRudin教授团队提出的SBRL)),并且可以实现与不可解释的复杂模型类似的分类性能,例如集成学习模型RandomForest和XGBoost、分段线性神经网络PLNN。此外,RRL可以轻松地在分类精度和模型复杂度之间进行权衡,以满足不同场景的需求。研究背景和动机尽管深度神经网络在许多机器学习任务中取得了令人瞩目的成果,但其不可解释的性质仍然使它们受到批评。尽管可以使用代理模型(SurrogateModels)、隐藏层调查(HiddenLayerInvestigation)等事后方法来解释深度网络,但这些方法的保真度、一致性和特异性或多或少存在问题。相比之下,基于规则的模型(Rule-basedModel),如决策树,得益于其透明的内部结构和良好的模型表达能力,在医疗、金融等对模型可解释性要求高的领域仍然发挥着作用,和政治。扮演一个重要角色。然而,传统的基于规则的模型由于其离散的参数和结构而难以优化,特别是在大规模数据集上,这严重限制了基于规则的模型的应用范围。然而,集成模型、软规则和模糊规则等,提高了分类和预测能力,却牺牲了模型的可解释性。为了在更多场景中发挥基于规则模型的优势,迫切需要解决以下问题:如何在保持可解释性的同时提高基于规则模型的可扩展性?图1:传统的基于规则的模型及其扩展模型Rule-basedrepresentationlearner为了解决上述问题,本文提出了一种新的基于规则的模型,基于规则的表示学习器(Rule-basedRepresentationLearner,RRL),对于可解释的分类任务。为了获得良好的模型透明性和表现力,RRL被设计为层次模型(如图2所示),由二值化层、若干逻辑层、线性层以及层与层之间的连接组成:二值化层(BinarizationLayer)用于划分连续值特征。结合逻辑层,可以实现特征的端到端离散化。逻辑层用于自动学习规则表示。每个逻辑层由一个合取层和一个析取层组成。合取范式和析取范式可以用两个逻辑层表示。线性层(LinearLayer)用于输出分类结果。可以更好地拟合数据的线性部分。权重可以用来衡量规则的重要性。SkipConnection用于自动跳过不需要的层。图2:规则表示学习器的示例。虚线框中显示了离散逻辑层及其相应的规则。逻辑层逻辑层使用逻辑规则自动学习数据表示。为实现这一点,逻辑层设计为具有离散和连续版本。两者共享参数,但离散版本用于训练、测试和解释,而连续版本仅用于训练。离散逻辑层逻辑层中的每个节点代表一个逻辑运算,包括合取和析取,层与层之间的边连接表示哪些变量参与运算。离散逻辑层节点对应的逻辑运算如下,其中和分别是合取层和析取层的节点,是邻接矩阵。离散逻辑层的一个具体例子如图2中的虚线框所示。通过学习边的连接,逻辑层可以灵活地表示结合或析取范式的离散分类规则。然而,问题是离散逻辑层虽然具有很好的可解释性,但它不能自导且难以训练,这就是为什么需要相应的连续版本的逻辑层。连续逻辑层连续逻辑层必须是可微的,当对连续逻辑层的参数进行二值化时,可以直接得到其对应的离散逻辑层。为此:将0/1邻接矩阵替换为[0,1]之间的实权矩阵将逻辑运算替换为逻辑激活函数传统的逻辑激活函数(PayaniandFekri,2019)如下,其中和是FetchLayer和ContinuousExtractionLayer分别是连续组合节点。但是,两次pass的大小决定了对最终结果影响的大小。如果=0,则对最终结果没有影响。这两个逻辑激活函数虽然可以更好地模拟具有可导实数运算的逻辑运算,但是它们存在严重的梯度消失问题,无法处理大量特征的情况,可扩展性差。分析逻辑激活函数和对应的导数可以发现,使用乘法模拟逻辑运算是导致梯度消失的主要原因。举个例子,对应的导数如下:因为,当乘法次数很大时(一般指大量的特征或大量的节点),导数结果会趋于0,即出现梯度消失的问题。逻辑激活函数改进了传统的逻辑激活函数,因为它使用连续乘法来模拟逻辑运算,所以在处理很多特征时,会出现梯度消失的问题,严重损害了模型的可扩展性。一个直接的改进思路是使用对数函数将乘法转换为加法。但是,对数函数使得激活函数无法保持逻辑运算的特性。因此,需要一个映射函数,至少需要满足以下三个条件:条件(i)和(ii)用于维持逻辑激活函数的范围和趋势,而条件(iii)需要高-阶无穷小,主要是用来减缓它当时向0的速度。取,所以对逻辑激活函数的改进可以概括为,改进后的逻辑激活函数为:二值化层二值化层主要用于将连续的特征值划分为若干个单元。对于第j个连续值特征,有k个随机下界和k个随机上界对它进行划分,然后得到如下二元向量,其中可以学习到逻辑层的边连接,所以通过组合一个二元值层和逻辑层,模型可以自动选择合适的边界进行特征离散化(二值化),即以端到端的方式对特征进行二值化。例如:当一个conjunctive层节点与and连接时,它表示一个区间当一个disjunctive层节点与and连接时,它表示一个区间搜索离散值的解决方案仍然是一个巨大的挑战。另外,logistic激活函数的特性导致RRL在离散点的梯度包含的有用信息很少,所以像Straight-ThroughEstimator(STE)这样的方法无法训练RRL。为了高效地训练不可微的RRL,本文提出了一种新的基于梯度的离散模型训练方法——梯度嫁接法。在植物嫁接中(如图3a所示),一株植物的枝条或芽作为接穗,另一株植物的根或茎作为砧木。新植物”。GradientGrafting的灵感来自于植物嫁接方法,将损失函数对离散模型输出的梯度作为接穗,将连续模型输出对模型参数的梯度作为接穗砧木参数的反向传播路径(如图3b所示)令为t时刻的参数,分别为离散模型和连续模型的输出,则:梯度嫁接法利用t时刻的梯度信息连续点和参数空间中的离散点,实现离散模型的直接优化。图3:(a)植物嫁接实例(Chenetal.,2019)。(b)简化计算图梯度嫁接方法,实线箭头和虚线箭头分别表示正向传播和反向传播,绿色箭头代表嫁接梯度,是红色箭头代表梯度的拷贝,嫁接后有一个损失函数和参数之间的反向传播路径。实验论文通过实验评估RRL并回答以下问题:RRL的分类性能和模型复杂度如何?与其他离散模型训练方法相比,GradientGrafting如何收敛?改进后的逻辑激活函数的可扩展性如何?作者在9个小型数据集和4个大型数据集上进行了实验。这些数据集被广泛用于测试模型的分类性能和可解释性。表1总结了这13个数据集的基本信息。可以看出,这13个数据集充分体现了数据的多样性:实例数从178个到102944个,类别数从2个到26个,原始特征个数从4个到4714个。另外,数据集的特征类型和稀疏性各不相同。表1:数据集统计分类效果论文比较了RRL与6个可解释模型和5个复杂模型的分类效果(F1Score),结果如表2所示。其中,C4.5(Quinlan,1993),CART(Breiman,2017年)、可扩展贝叶斯规则列表(SBRL)(Yang等人,2017年)、可证明最优规则列表(CORELS)(Angelino等人,2017年)和概念规则集(CRS)(Wang等人,2020)是一个基于规则的模型,而逻辑回归(LR)(Kleinbaumetal.,2002)是一个线性模型。这六个模型被认为是可解释的。分段线性神经网络(PLNN)(Chuetal.,2018)、支持向量机(SVM)(ScholkopfandSmola,2001)、随机森林(Breiman,2001)、LightGBM(Keetal.,2017)和XGBoost(Chen和Guestrin,2016)被认为是难以解释的复杂模型。PLNN是一类使用分段线性激活函数的多层感知器(MLP)。RF、LightGBM和XGBoost都是集成模型。可以看出,RRL明显优于其他可解释模型,只有LightGBM和XGBoost这两个复杂模型具有可比较的结果。此外,RRL在所有数据集上都取得了不错的成绩,这也证明了RRL良好的可扩展性。表2:各模型在13个数据集上的分类效果(五折交叉验证F1Score)模型复杂度可以解释模型追求在保证可接受精度的前提下尽可能降低模型复杂度。如果模型的分类性能太差,那么再低的模型复杂度也没有意义。因此,从业者真正关心的是模型分类性能与复杂度之间的关系。考虑到规则重用的存在,论文使用边的总数而不是规则的总数来衡量基于规则的模型的复杂性(可解释性)。RRL、CART、CRS和XGBoost的模型复杂度与模型分类效果的关系如图4所示,其中横轴为复杂度,纵轴为分类效果。可以看出,与其他规则模型和集成模型相比,RRL可以更有效地使用规则,即以更低的模型复杂度获得更好的分类结果。结果还表明,RRL可以通过参数设置轻松地在模型复杂性和分类性能之间进行权衡。图4:模型复杂度与RRL和基线模型分类性能的散点图。AblationExperimentDiscreteModelTrainingMethod通过训练具有相同结构的RRL,作者将梯度嫁接方法与STE(Courbariauxetal.,2015,2016)、ProxQuant(Baietal.,2018)和RB(Wangetal.,2020)对这三类离散模型训练方法进行了比较,训练损失函数结果如图5所示。由于RRL本身的特殊结构(即离散点处的梯度信息很少),只有用梯度嫁接训练的RRL可以很好地收敛。改进后的逻辑激活函数改进前后的结果也如图5所示。可以看出,在处理大规模数据时,逻辑激活函数会存在梯度消失的问题,导致不收敛.改进的逻辑激活函数克服了这个问题。图5:梯度嫁接和其他三种离散模型训练方法的训练损失,以及改进前后使用逻辑激活函数的训练损失。示例显示权重分布图6显示了不同正则化系数对应的RRL线性层权重(规则重要性)的分布。当正则化项的系数比较小的时候,RRL生成的规则比较复杂,数量也比较多。但从分布可以看出,绝大部分是绝对权重较小的规则。因此,可以先了解权重较大的重要规则,等对模型整体和数据有了一定了解后,再了解权重较小的规则。当正则化项的系数较大时,RRL的整体复杂度较低,可以直接理解整体模型。图6:不同正则化系数的线性层权重分布。具体规则图7展示了银行营销数据集学习到的一些规则,用于预测用户是否会接受电话营销中的银行贷款。从这些规则中,我们可以直观地看出哪些用户状态和公司行为对销售结果有影响,比如中年已婚低存款用户更有可能接受贷款。银行可以根据这些可解释的规则调整营销策略,以增加销售额。尽管RRL并不是专门为图像分类任务设计的,但得益于其良好的可扩展性,RRL仍然可以通过可视化的方式为图像分类任务提供直观的解释。图8是RRL在fashion-mnist图像数据集上学习到的规则的可视化。由此,我们可以直观地总结出模型的决策模式,比如通过袖长区分T恤和套头衫。图7:RRL在银行营销数据集上学到的一些规则。图8:RRL在fashion-mnist图像数据集上学习的规则的可视化。摘要论文提出了一种新的可扩展分类器,即正则表示学习器(RRL)。RRL能够通过自动学习可解释的非模糊规则来进行数据表示和分类。得益于自身的模型结构设计、梯度嫁接方法,以及使用改进版的逻辑激活函数,RRL不仅具有很强的可扩展性,而且在模型复杂度较低的前提下,也能取得较好的分类效果。RRL的提出不仅使得可解释的规则模型适用于更大的数据规模和更广泛的应用场景,也为从业者提供了一种更好的权衡模型复杂度和分类效果的方法。在未来的工作中,我们会将RRL扩展到非结构化数据,例如图像和文本,以提高此类数据模型的可解释性。