鱼和熊掌!BAIR发布了一项关于神经支持决策树的新研究,兼顾了准确性和可解释性。随着深度学习在金融、医疗等领域的不断落地,模型的可解释性成为了一个非常大的痛点,因为这些领域需要能够准确预测并能够解释其行为的模型。然而,众所周知,深度神经网络缺乏可解释性,这就产生了矛盾。可解释人工智能(XAI)试图平衡模型准确性和可解释性之间的矛盾,但XAI在解释决策原因时并不直接解释模型本身。决策树是一种经典的机器学习分类方法。它易于理解和解释,并且可以在中等规模数据上以较低的难度获得更好的模型。之前流行的微软小冰读心术,极有可能用到了决策树。小冰会先让我们想象一个名人(你需要出名),然后问我们15个以内的问题。我们只需要回答是、否或不知道,小冰就可以很快猜出我们的想法是谁。周志华老师曾在《西瓜书》中展示过决策树的示意图:决策树示意图。尽管决策树有很多优点,但历史经验告诉我们,在遇到ImageNet级别的数据时,其性能仍然远不如神经网络。“精准”与“解读”,如何“鱼”与“熊掌”兼得?将两者结合起来会怎样?最近,加州大学伯克利分校和波士顿大学的研究人员对这一想法进行了测试。他们提出了一种神经支持决策树“Neural-backeddecisiontrees”,在ImageNet上取得了75.30%的top-1分类准确率,同时保留了决策树的可解释性,同时达到了目前神经网络所能达到的准确率。该比率比其他基于决策树的图像分类方法高出约14%。BAIR博客地址:https://bair.berkeley.edu/blog/2020/04/23/decisions/论文地址:https://arxiv.org/abs/2004.00221开源项目地址:https://github.com/新提出的alvinwan/neural-backed-decision-trees方法的可解释性如何?我们来看两张图。深度神经网络在OpenAIMicroscope中的可视化是这样的:而论文提出的方法在CIFAR100上的可视化结果是这样的:哪种方法在图像分类上的可解释性更强已经很明显了。决策树的优势和劣势在深度学习风靡一时之前,决策树是准确性和可解释性的基准。下面,我们首先说明决策树的可解释性。如上图所示,这棵决策树不仅给出了输入数据x的预测结果(是“超级汉堡”还是“华夫饼”),还输出了一系列导致最终预测的中间决策.我们可以验证或质疑这些中间决定。然而,在图像分类数据集上,决策树的准确性落后于神经网络40%。神经网络和决策树的组合也表现不佳,甚至无法与CIFAR10数据集上的神经网络相提并论。这种精度缺陷使得可解释性的优势“一文不值”:我们首先需要一个精度高的模型,但这个模型还必须是可解释的。接近神经支持的决策树现在,这个难题终于有了进展。加州大学伯克利分校和波士顿大学的研究人员通过构建既可解释又准确的模型解决了这个问题。研究的重点是将神经网络和决策树结合起来,在使用神经网络进行低级决策的同时保持高层的可解释性。如下图所示,研究人员将这个模型称为“神经支持决策树(NBDT)”,并表示这个模型可以匹配神经网络的准确率,同时保留了决策树的可解释性。在这个图中,每个节点都包含一个神经网络,上图放大以标记这样一个节点及其包含的神经网络。在此NBDT中,预测是通过决策树进行的,保留了高水平的可解释性。但是决策树上的每个节点都有一个神经网络用于做出低级决策。比如上图中神经网络做出的低层决策是“有香肠”还是“没有香肠”。NBDT与决策树一样具有可解释性。并且NBDT可以输出预测结果的中间决策,这比目前的神经网络要好。如下图所示,在预测“狗”的网络中,神经网络可能只输出“狗”,但NBDT可以输出“狗”和其他中间结果(动物、脊索动物、食肉动物等)。此外,NBDT的预测层次轨迹也被可视化,可以说明哪些可能性被拒绝了。同时,NBDT也取得了媲美神经网络的准确率。在CIFAR10、CIFAR100和TinyImageNet200等数据集上,NBDT的准确性接近神经网络(HowNeuralSupportDecisionTreesAreInterpreted的差距)个体预测的辩证论证对于模型从未见过的辩证法而言是最有信息量的。例如,考虑在Zebra上运行的NBDT(如下图所示)。虽然这个模型从未见过斑马,但下图所示的中间决策是正确的——斑马既是动物又是有蹄类动物。对于前所未见的物体,个体预测的合理性至关重要。模型行为的辩证论证此外,研究人员发现,使用NBDT,可解释性会随着准确性的提高而提高。这与文章开头介绍的准确性和可解释性的对立背道而驰,即:NBDT不仅具有准确性和可解释性,而且还要使准确性和可解释性成为同一目标。ResNet10层次结构(左)不如WideResNet层次结构(右)。例如,ResNet10在CIFAR10上的准确度比WideResNet28x10低4%。相应地,较低精度的ResNet^6层次结构(左)将青蛙、猫和飞机组合在一起,意义较小,因为很难找到这三个类共有的视觉特征。相比之下,更高精度的WideResNet层次结构(右)更有意义,将动物与汽车完全分开。所以可以说准确率越高,NBDT越容易解读。理解决策规则决策树中的决策规则在处理低维表格数据时很容易解释,例如,如果盘子里有面包,则分配给适当的孩子(如下所示)。然而,对于像高维图像这样的输入,决策规则并不是那么简单。该模型的决策规则不仅基于对象类型,还基于上下文、形状和颜色等。这个案例展示了如何使用低维表格数据轻松解释决策规则。为了定量解释决策规则,研究人员使用了WordNet3现有的名词层次结构;通过这种层次结构,可以找到类别之间最具体的共享含义。例如,给定类别猫和狗,WordNet将返回哺乳动物。在下图中,研究人员定量验证了这些WordNet假设。左侧从属树(红色箭头)的WordNet假设是Vehicle。右侧的WordNet假设(蓝色箭头)是Animal。值得注意的是,在具有10个类的小型数据集中(例如CIFAR10),研究人员可以找到所有节点的WordNet假设。然而,在具有1000个类的大型数据集中(即ImageNet),WordNet假设只能在节点的子集中找到。HowitWorksNeural-Backed决策树的训练和推理过程可以分解为以下四个步骤:为决策树构建一个称为诱导层次结构“InducedHierarchy”的层;这一层产生一个树监督损失称为“树监督损失”一个独特的损失函数;推理从将样本传递到神经网络主干开始。在最后一个全连接层之前,主干网络是一个神经网络;最后一个全连接层以顺序决策规则的方式运行以结束推理,研究人员将其称为“嵌入式决策规则”。Neural-Backed决策树训练和推理示意图。运行嵌入式决策规则这里首先讨论推理问题。如前所述,NBDT使用神经网络主干为每个样本提取特征。为了便于理解接下来的操作,研究者先构造了一个等价于全连接层的退化决策树,如下图所示:上面生成了一个矩阵-向量乘法,然后变成了一个矩阵-向量乘法的内积向量,在这里表示为$\hat{y}$。上面输出的最大值的索引就是类别的预测。简单决策树(naivedecisiontree):研究人员构建了一个基本决策树,每个类别只有一个根节点和一个叶节点,如上图“B-Naive”所示。每个叶节点直接连接到根节点并具有表示向量(来自W的行向量)。使用从样本中提取的特征x进行推断意味着计算x与每个子节点的表示向量的内积。与全连接层类似,最大内积的索引就是预测的类。全连接层和简单决策树之间的直接等价关系启发研究人员提出了一种特殊的推理方法——使用内积的决策树。构建归纳级别该级别确定NBDT需要做出决策的类别集。由于预训练神经网络的权重用于构建该层,因此研究人员将其称为诱导层。具体来说,研究人员将全连接层中的权重矩阵W的每一行向量都视为d维空间中的一个点,如上文“步骤B”所示。接下来,对这些点进行层次聚类。该层次结构遵循连续的聚类。使用TreeSupervisedLoss进行训练考虑上图中的“A-Hard”案例。假设绿色节点对应于Horse类。它只是一个类,它也是一个动物(橙色)。对于结果,也知道到达根节点(蓝色)的样本应该是在右边的动物处。到达节点动物“Animal”的样本也应该再次右转到“Horse”。训练的每个节点用于预测正确的子节点。研究人员将强制执行此损失的树木称为树木监督损失。换句话说,这实际上是每个节点的交叉熵损失。使用指南我们可以直接使用Python包管理工具来安装nbdt:pipinstallnbdt安装好nbdt后,你可以在任何图片上进行推断。nbdt支持网页链接或本地图片。nbdthttps://images.pexels.com/photos/126407/pexels-photo-126407.jpeg?auto=compress&cs=tinysrgb&dpr=2&w=32#ORrunonalocalimagenbdt/imaginary/path/to/local/image.png没关系如果不想安装,研究人员为我们提供了网页版demo和Colab示例,地址如下:Demo:http://nbdt.alvinwan.com/demo/Colab:http://nbdt.alvinwan.com/notebook/Thecodebelowshowshow使用研究者提供的预训练模型进行推断:fromnbdt.modelimportSoftNBDTfromnbdt.modelsimportResNet18,wrn28_10_cifar10,wrn28_10_cifar100,wrn28_10#usewrn28_10forTinyImagenet200model=wrn28_10_cifar10()model=SoftNBDT(pretrained=True,dataset='CIFAR10',arch='wrn28_10_cifar10',model=model)此外,研究人员还提供了如何用不到6行代码将nbdt与我们自己的神经网络相结合,详见他们的GitHub开源项目。
