长码序列Transformer模型优化方法提升长码场景性能长码序列Transformer模型优化方法致力于提升长码场景下的效果和性能。由于自注意力模块的复杂性随序列长度呈指数增长,因此大多数基于编程的预训练语言模型(PPLM)通过序列截断来处理代码序列。SASA方法使self-attention的计算变得稀疏,结合代码的结构特点,从而提高长序列任务的性能,降低内存和计算复杂度。论文:TingtingLiu、ChengyuWang、CenChen、MingGao和AoyingZhou。了解具有结构感知稀疏注意力的长编程语言。SIGIR2022ModelFramework下图是SASA的整体框架:其中SASA主要包括两个阶段:Preprocessingstage和SparseTransformertrainingstage。在预处理阶段,得到两个token之间的交互矩阵,一个是top-k频率矩阵,一个是AST模式矩阵。Top-k频率矩阵是使用代码预训练语言模型在CodeSearchNet语料库上学习到的token之间的注意力交互频率,AST模式矩阵是抽象语法树(AbstractSyntaxTree,AST)解析代码,根据语法树的连接关系得到token之间的交互信息。SparseTransformer训练阶段以TransformerEncoder为基础框架,用结构感知稀疏自注意力代替全自注意力,在满足特定模式的token对之间进行注意力计算,从而降低计算复杂度。SASAsparseattention包括以下四个模块:Slidingwindowattention:只计算滑动窗口内token之间的self-attention,保留局部上下文的特征,计算复杂度为,是序列长度,是滑动窗口大小。globalattention:设置一定的全局token,这些token会和序列中的所有token一起计算,从而得到序列的全局信息,计算复杂度为,即全局token的个数。Top-ksparseattention:Transformer模型中的attention交互是稀疏长尾的。对于每个token,只计算与其attention交互最高的top-ktoken进行attention,复杂度为。AST-awarestructureattention:代码不同于自然语言序列,具有更强的结构特征。通过将代码解析成抽象语法树(AST),然后根据语法树中的连接关系确定attention计算的范围。为了适应现代硬件并行计算的特点,我们将序列分成若干个块,而不是以token为单位进行计算。每个查询块使用滑动窗口块、全局块以及top-k和AST块来计算注意力。整体计算复杂度为b作为块大小。每个稀疏注意力模式对应一个注意力矩阵。以滑动窗口注意力为例,注意力矩阵的计算为:ASA伪代码:实验结果我们使用CodeXGLUE[1]提供的四个任务数据集进行评估,分别是代码克隆检测、缺陷检测、代码搜索、代码总结。我们提取序列长度大于512的数据,组成一个长序列数据集。实验结果如下:从实验结果可以看出,SASA在三个数据集上的性能均明显高于所有Baselines。其中,Roberta-base[2]、CodeBERT[3]、GraphCodeBERT[4]采用截断方式处理长序列,会丢失部分上下文信息。Longformer[5]和BigBird[6]是自然语言处理中用于处理长序列的方法,但没有考虑代码的结构特征,直接迁移到代码任务上效果不佳。为了验证top-ksparseattention和AST-awaresparseattention模块的效果,我们在BigCloneBench和DefectDetection数据集上进行了消融实验。结果如下:sparseattentionmodule不仅提升了长代码任务的性能,还可以大大降低显存占用。相同设备下,SASA可以设置更大的batchsize,而fullself-attention模型面临内存不足的问题。具体显存使用情况如下:SASA是一个稀疏注意力模块,可以迁移到其他基于Transformer的预训练模型处理长序列自然语言处理任务,然后集成到开源框架EasyNLP(https://github.com/alibaba/EasyNLP)并为开源社区做出贡献。论文链接:https://arxiv.org/abs/2205.13730
