领域自适应是解决迁移学习的重要方法,目前的领域自适应方法依赖于原始域和目标域数据进行同步训练。当源域数据不可用且目标域数据不完全可见时,Test-TimeTraining成为一种新的域适配方法。目前关于测试时间训练(TTT)的研究广泛使用了自我监督学习、对比学习和自我训练等方法。然而,如何在真实环境中定义TTT往往被忽略,以至于不同方法之间缺乏可比性。近日,华南理工大学、A*STAR团队和鹏城实验室联合提出了一个系统的TTT问题分类标准。通过区分方法是否具有顺序推理能力(SequentialInference)以及源域训练目标是否需要修改,对当前方法做了详细的分类。同时,提出了一种基于目标域数据的锚定聚类方法,在各种TTT分类下均取得了最高的分类准确率。本文后续对TTT的研究指出了正确的方向,避免了实验设置的混乱。结果是无与伦比的。研究论文已被NeurIPS2022接收。论文:https://arxiv.org/abs/2206.02721代码:https://github.com/Gorilla-Lab-SCUT/TTAC1.Introduction深度学习的成功主要归功于对大量标注数据和训练、测试集独立同分布的假设。一般情况下,在合成数据上训练,然后在真实数据上进行测试时,无法满足上述假设,也称为领域转移。为了缓解这个问题,领域适应(DA)诞生了。现有的DA工作要么需要在训练期间访问源域和目标域的数据,要么同时在多个域上进行训练。前者要求模型在Adaptation训练期间始终可以访问源域数据,而后者需要更昂贵的计算。为了减少对源域数据的依赖,由于隐私问题或存储开销导致源域数据无法访问,无源域数据的无源域适配(Source-FreeDomainAdaptation,SFDA)解决了域适配问题不可访问的源域数据。作者发现SFDA需要在整个目标数据集上进行多轮训练才能达到收敛,而SFDA在面对流数据时无法解决此类问题,需要及时做出推理预测。这种需要及时适应流数据并进行推理和预测的更现实的设置称为测试时间训练(TTT)或测试时间适应(TTA)。作者指出,社区对TTT的定义存在混淆,导致比较不公平。论文将现有的TTT方法归类为两个关键因素:对于流式传输的数据,需要对当前数据做出及时的预测,称为One-PassAdaptation协议(One-PassAdaptation);对于不满足上述设置的Others称为Multi-PassAdaptation。在从头到尾进行推理预测之前,模型可能需要在整个测试集上进行多轮更新。根据源域的训练损失方程是否需要修改,比如引入额外的自监督分支来实现更有效的TTT。本文的目标是解决最现实和最具挑战性的TTT协议,即在不修改训练损失方程的情况下进行单轮自适应。该设置类似于TENT[1]提出的TTA,但不限于使用来自源域的轻量级信息,例如特征统计。鉴于在测试时有效适应TTT的目标,该假设在计算上是有效的并且大大提高了TTT的性能。作者将这个新的TTT协议命名为顺序测试时间训练(sTTT)。除了以上对不同TTT方法的分类外,论文还提出了两种使sTTT更有效和准确的技术:论文提出了Test-TimeAnchoredClustering(TTAC)方法。为了减少错误的伪标签对聚类更新的影响,论文根据网络对样本预测的稳定性和置信度来过滤伪标签。2.方法介绍本文分为四个部分来解释所提出的方法,分别是1)介绍测试时间训练(TTT)的anchoredclustering模块,如图1中AnchoredClustering部分所示;2)介绍用于过滤伪标签的一些策略,如图1中的PseudoLabelFilter部分;3)不同于TTT++[2]中使用L2距离来衡量两个分布之间的距离,作者使用KL散度来衡量两个全局特征分布之间的距离4)引入一种迭代的方法来在测试过程中高效地更新特征统计量-时间训练(TTT)过程。最后,第五小节给出了整个算法的过程代码。第一部分在锚定聚类中,作者首先使用高斯混合对目标域的特征进行建模,其中每个高斯分量代表一个已发现的聚类。然后作者使用源域中每个类别的分布作为目标域分布的锚点进行匹配。这样,测试数据特征可以同时形成聚类,聚类与源域类别相关联,从而实现对目标域的泛化。综上所述,根据类别信息对源域和目标域的特征进行建模:然后通过KL散度度量两个混合高斯分布之间的距离,通过减少KL散度。然而,直接在两个混合高斯分布上求解KL散度没有封闭形式的解,这阻碍了高效梯度优化方法的使用。在这篇论文中,作者在源域和目标域分配相同数量的簇,每个目标域簇分配给一个源域簇,这样整个混合高斯的KL散度解就可以转化为对高斯分布之间的KL散度之和。下式:上式的封闭式解为:式2中,源域集群的参数可以离线采集,由于只使用了轻量级的统计数据,不会造成隐私泄露,只有很小的一部分大量的计算和存储开销。对于目标域中的变量,涉及到伪标签的使用,作者为此设计了一套有效且轻量级的伪标签过滤策略。第二部分伪标签过滤的策略主要分为两部分:1)时间序列一致性预测的过滤:2)基于后验概率的过滤:最后利用过滤后的样本求解目标域的统计量cluster:第一部分第三部分是在anchorclustering中,过滤后的部分样本没有参与目标域的估计。作者还对所有测试样本进行了全局特征对齐,类似于anchorclustering中的聚类方法。这里把所有的样本看成一个整体的cluster,在源域和目标域分别定义,然后最小化KL散度又是TargetAlignmentGlobalFeatureDistribution:PartFour上面三部分都是在介绍一些域对齐的方法,但是在TTT过程中,估计一个目标域的分布并不容易,因为我们无法观察到整个目标域的数据。在前沿工作中,TTT++[2]使用特征队列存储过去的部分样本来计算局部分布以估计整体分布。但这不仅会带来内存开销,还会导致精度和内存之间的权衡。在本文中,作者提出了统计信息的迭代更新以减轻内存开销。具体的迭代更新公式如下:总的来说,整个算法如下面的算法1所示:3.实验结果如前言所述,本文作者非常注重不同TTT下不同方法的公平比较策略。作者根据以下两个关键因素对所有TTT方法进行了分类:1)是否使用单程自适应协议(One-PassAdaptation)和2)修改源域的训练损失方程,记为Y/N以表示是否需要修改源域训练方程,O/M表示单轮自适应或多轮自适应。此外,作者对6个基准数据集进行了充分的对比实验和进一步的分析。如表1所示,TTT++[2]同时出现在N-O和Y-O协议下,因为TTT++[2]多了一个自监督分支,我们不添加N-O下自监督分支的损失协议,并且这个分子的丢失在Y-O下工作正常。TTAC也使用与Y-O下的TTT++[2]相同的自监督分支。从表中可以看出,TTAC在所有TTT协议下的所有数据集下都取得了最好的成绩;在CIFAR10-C和CIFAR100-C数据集上,TTAC取得了3%以上的提升。从表2-表5分别是ImageNet-C、CIFAR10.1、VisDA上的数据,TTAC取得了最好的结果。另外,作者同时在多个TTT协议下进行了严格的消融实验,清楚地看到了各个组件的作用,如表6所示。首先,从L2Dist和KLD的对比可以看出:用KL散度衡量两个分布效果更好;其次,发现如果单独使用AnchoredClustering或者pseudo-labelsupervision,提升只有14%,但是如果结合AnchoredCluster和PseudoLabelFilter,可以看到明显的性能提升29.15%->11.33%。这也说明了各个组成部分结合的必要性和有效性。最后,作者在文末从五个维度全面分析了TTAC,即sTTT下的累积性能(N-O)、TTAC特征的TSNE可视化、独立于源域的TTT分析、测试样本队列和更新轮次的分析和计算开销以挂钟时间测量。还有更多有趣的证明和分析将在文章的附录中展示。4.总结本文仅简单介绍了TTAC工作的贡献点:现有TTT方法的分类与比较,提出的方法,以及在各个TTT协议分类下的实验。论文和附录中会有更详细的讨论和分析。我们希望这项工作可以为TTT方法提供一个公平的基准,未来的研究应该在各自的协议中进行比较。
