当前位置: 首页 > 科技观察

为什么基于树的模型仍然优于表格数据的深度学习?

时间:2023-03-21 20:42:37 科技观察

深度学习在图像、语言甚至音频等领域取得了巨大进步。然而,在处理表格数据时,深度学习表现平平。由于表格数据具有参差不齐、样本量小、极值大等特点,很难找到相应的不变量。基于树的模型不可微,无法与深度学习模块联合训练,因此创建特定于表的深度学习架构是一个非常活跃的研究领域。许多研究声称击败或匹配基于树的模型,但他们的研究遭到了很多质疑。事实上,由于缺乏既定的表格数据学习基准,研究人员在评估他们的方法时有很大的自由度。此外,与其他机器学习子领域的基准相比,大多数在线可用的表格数据集都很小,这使得评估更加困难。为了缓解这些担忧,法国国家信息学与自动化研究所、索邦大学和其他机构的研究人员提出了一个表格数据基准,该基准能够评估最先进的深度学习模型并表明基于树的模型执行在中型表格数据集上效果更好。还是SOTA。对于这一结论,该论文提供了确凿的证据,表明在表格数据上,使用基于树的方法比深度学习(甚至现代架构)更容易实现良好的预测,研究人员探索了原因。论文地址:https://hal.archives-ouvertes.fr/hal-03723551/document值得一提的是,该论文的作者之一是Ga?lVaroquaux,他是Scikit-learn项目的负责人之一。目前该项目已成为GitHub上最受欢迎的机器学习库之一。而Ga?lVaroquaux参与的文章《Scikit-learn: Machine learning in Python》被引用58,949次。本文的贡献可以概括如下:本研究为表格数据(选择45个开放数据集)创建了一个新的基准,并通过OpenML共享这些数据集,使其易于使用。该研究在表格数据的多种设置下比较了深度学习模型和基于树的模型,同时考虑了选择超参数的成本。该研究还分享了随机搜索的原始结果,这将使研究人员能够以低成本测试新算法以获得固定的超参数优化预算。在表格数据上,基于树的模型仍然优于深度学习方法新基准参考了45个表格数据集,选择基准如下:异构列,列应对应不同属性的特征,从而排除图像或信号数据集。维度低,数据集d/n比低于1/10。无效的数据集,删除可用信息很少的数据集。身份证(独立同分布)数据,去除流状数据集或时间序列。真实世界的数据,去除人工数据集但保留一些模拟数据集。数据集不能太小,删除特征太少(<4)和样本太少(<3000)的数据集。删除过于简单的数据集。删除了扑克和国际象棋等游戏的数据集,因为这些数据集都是针对确定性的。在基于树的模型中,研究人员选择了3种SOTA模型:ScikitLearn的RandomForest、GradientBoostingTrees(GBTs)、XGBoost。该研究对以下深度模型进行了基准测试:MLP、Resnet、FTTransformer、SAINT。图1和图2显示了不同类型数据集的基准测试结果实证研究:为什么基于树的模型在表格数据上仍然优于深度学习归纳偏差。基于树的模型在各种超参数选择中击败了神经网络。事实上,处理表格数据的最佳方法有两个共同特性:它们是集成方法,bagging(随机森林)或boosting(XGBoost,GBT),这些方法中使用的弱学习器是决策树。发现1:神经网络(NN)倾向于过度平滑的解决方案会影响NN。这些结果表明,与基于树的模型相比,数据集中的目标函数并不平滑,神经网络难以适应这些不规则函数。这与Rahaman等人的发现一致,他们发现神经网络偏向于低频函数。基于决策树的模型学习没有这种偏差的分段常数函数。Finding2:Non-informativefeaturescanmoreaffectMLP-likeNN表数据集包含很多非信息(uninformative)特征,对于每个数据集,研究会根据特征的重要性选择丢弃一定比例的特征(通常根据随机森林排序)。从图4可以看出,去掉一半以上的特征对GBT的分类准确率影响不大。图5显示去除非信息特征(5a)减少了MLP(Resnet)和其他模型(FTTransformers和基于树的模型)之间的性能差距,而添加非信息特征扩大了差距,这表明MLP具有对非信息特征的显着影响不太稳健。在图5a中,当研究者移除较大比例的特征时,相应的有用信息特征也被移除。图5b表明去除这些特征造成的精度损失可以通过去除非信息特征来补偿,这对MLP比其他模型更有帮助(同时本研究也去除了冗余特征,不会影响模型表现)。发现3:数据在旋转时是非不变的为什么MLP比其他模型更容易受到无信息特征的影响?一个答案是MLP是旋转不变的:当对训练和测试集特征应用旋转时,在训练集上学习MLP并在测试集上对其进行评估的过程是不变的。事实上,任何旋转不变的学习过程都有一个最坏情况的样本复杂度,它至少在不相关特征的数量上呈线性增长。直观上,为了去除无用的特征,旋转不变算法必须首先找到特征的原始方向,然后选择信息最少的特征。图6a显示了将随机旋转应用于数据集时测试精度的变化,确认只有Resnet是旋转不变的。值得注意的是,随机旋转颠倒了性能顺序:结果是NNs在基于树的模型之上,Resnets在FTTransformer之上,这表明旋转不变性是不可取的。事实上,表格数据通常具有单独的含义,例如年龄、体重等。图6b显示,删除每个数据集中最不重要的一半特征(在旋转之前)会降低除Resnets之外的所有模型的性能,但与旋转时相比使用所有功能而不删除功能。,有小幅下降。