当前位置: 首页 > 科技观察

为什么基于树的模型在表格数据上仍然优于深度学习

时间:2023-03-15 13:06:57 科技观察

在这篇文章中,我将详细解释论文《Why do tree-based models still outperform deep learning on tabular data》,它解释了世界各地的机器学习从业者在各种领域观察中使用的方法——基于树的模型比深度学习/神经网络更擅长分析表格数据。论文注释这篇论文做了很多预处理。例如,删除缺失数据之类的事情会影响树的性能,但是随机森林非常适合缺失数据,如果你的数据非常混乱:包含很多特征和维度。RF的稳健性和优势使其优于更容易出现问题的更“高级”解决方案。其余大部分工作非常标准。我个人不太喜欢应用太多预处理技术,因为这会导致丢失数据集的许多细微差别,但论文中采取的步骤基本上会产生相同的数据集。但需要注意的是,在评价最终结果时采用了同样的处理方式。该论文还使用随机搜索进行超参数调整。这也是行业标准,但根据我的经验,贝叶斯搜索更适合在更广泛的搜索空间中进行搜索。知道了这一点,我们就可以深入探讨我们的主要问题——为什么基于树的方法比深度学习更好?1.神经网络偏向过于平滑的解这是作者第一次分享深度学习神经网络无法与随机森林抗衡的原因。简而言之,当涉及到非平滑函数/决策边界时,神经网络很难创建最佳拟合函数。随机森林在怪异/锯齿状/不规则模式下表现更好。如果让我猜原因,可能是神经网络中使用了梯度,而梯度依赖于可微搜索空间,根据定义是平滑的,所以无法区分尖点和一些随机函数。所以我建议学习AI概念,比如进化算法,传统搜索等等更基础的概念,因为这些概念在神经网络失败的各种情况下都能取得很好的效果。有关基于树的方法(RandomForests)和深度学习器之间决策边界差异的更具体示例,请看下图-在附录中,作者对上述可视化进行了说明:在本节中,我们可以看到RandomForest能够学习MLP无法学习的x轴上的不规则模式(对应于日期特征)。我们展示了默认超参数的这种差异,这是神经网络的典型行为,但在实践中很难(尽管并非不可能)找到成功学习这些模式的超参数。2.非信息属性可能影响MLP-like神经网络的另一个重要因素,特别是对于那些同时编码多个关系的大型数据集。向神经网络提供不相关的特征将以糟糕的方式结束(并且您将浪费更多资源来训练您的模型)。这就是为什么花大量时间进行EDA/领域探索如此重要的原因。这将有助于了解功能并确保一切顺利进行。该论文的作者在添加随机和删除无用特征时测试了模型的性能。根据他们的结果,发现了两个有趣的结果。删除大量特征可以减少模型之间的性能差距。这清楚地表明,树模型的一大优势是能够判断特征是否有用,避免无用特征的影响。向数据集添加随机特征显示神经网络的退化比基于树的方法严重得多。ResNet特别受这些无用属性的困扰。transformer的改进可能是因为里面的attention机制会有一定的帮助。对这种现象的一种可能解释是决策树的设计方式。任何上过人工智能课程的人都会知道决策树中信息增益和熵的概念。这使得决策树能够通过比较剩余的特征来选择最佳路径。回到主题,还有最后一件事让RF在表格数据方面表现优于NN。即旋转不变性。3.NNs是旋转不变的,但实际数据并不是神经网络是旋转不变的。这意味着如果对数据集执行旋转操作,它不会改变它们的性能。旋转数据集后,不同模型的性能和排名发生了很大变化,虽然ResNets始终是最差的,但旋转后它保持了原来的性能,而其他所有模型都发生了很大变化。这个现象很有意思:旋转数据集究竟意味着什么?整篇论文都没有详细的细节(我已经联系了作者,会继续跟进这个现象)。如果您有任何想法,请在评论中分享。但是这个操作让我们明白为什么旋转方差很重要。根据作者的说法,采用特征的线性组合(这使得ResNets不变)实际上可能会歪曲特征及其关系。通过对原始数据进行编码来获得最佳数据偏差,这可能会混合具有非常不同统计特性的特征,并且无法通过旋转不变模型恢复,这将为模型提供更好的性能。总之这是一篇非常有趣的论文,虽然深度学习在文本和图像数据集上取得了长足的进步,但它在表格数据上的优势不大。该论文使用来自不同领域的45个数据集进行测试,结果表明即使不考虑其优越的速度,基于树的模型在中等数据(~10K样本)上仍然是最先进的。