当前位置: 首页 > 科技观察

斯坦福大学CS博士新作:新Attention速度提升2-4倍,BERT单节点训练最快

时间:2023-03-16 20:00:14 科技观察

一种快速、节省内存的attention算法来了,名字叫FlashAttention。通过减少GPU内存读/写,FlashAttention的运行速度比PyTorch标准注意力快2-4倍,并且需要的内存减少5-20倍。该研究由斯坦福大学和纽约州立大学布法罗分校的研究人员共同完成。由两名斯坦福计算机博士生TriDao和DanFu合着。下面我们介绍论文的具体内容。FlashAttentionTransformer已经成为自然语言处理、图像分类等应用中使用最广泛的架构。随着研究的不断推进,Transformer的尺寸越来越大,越来越深,但是由于Transformer核心self-attention模块的时间复杂度和内存复杂度,仍然很难给Transformer配备更长的上下文。度数是序列长度的二次方。一些研究人员提出了一些近似注意力的方法,旨在减少注意力计算和记忆需求。这些方法包括稀疏近似、低秩近似及其组合。尽管这些方法可以在序列长度方面将计算减少到线性或接近线性,但它们没有显示出超过标准注意力的挂钟加速,因此没有被广泛使用。造成这种情况的主要原因之一是这些研究侧重于减少FLOP(这可能与挂钟速度无关)并且往往忽略内存访问(IO)的开销。在这篇论文中,研究认为注意力算法应该是IO感知的——也就是说,考虑在内存级别之间进行读取和写入。现代GPU比内存更快,并且Transformer中的大多数操作都被内存访问阻塞。当读取和写入数据占用大量执行时,IO感知算法对于类似内存绑定操作至关重要——例如数据库连接、图像处理、数值线性代数等。但是,用于深度学习的常见Python接口,例如与PyTorch和Tensorflow一样,不允许对内存访问进行细粒度控制。论文地址:https://arxiv.org/pdf/2205.14135.pdfGitHub地址:https://github.com/HazyResearch/flash-attention本研究提出了一种新的注意力算法FlashAttention,可以使用更少的内存访问来计算精确的注意力.FlashAttention旨在避免从HBM(高带宽内存)读取和写入注意力矩阵。这需要:(i)在不访问整个输入的情况下计算softmax缩减;(ii)在反向传播过程中不存储中间注意力矩阵。该研究采用两种成熟的技术来应对这些挑战:(i)该研究通过将输入分成块并多次遍历输入块来重新组织注意力计算,从而逐渐执行softmax缩减(也称为平铺);(ii)该研究在前向传递中存储softmax归一化因子,并在反向传递中快速重新计算片上注意力,这比从HBM读取中间注意力矩阵的标准方法更快。该研究在CUDA中实现FlashAttention以实现对内存访问的细粒度控制,并将所有注意力操作融合到单个GPU内核中。即使由于重新计算而增加了FLOP,它运行得更快(在GPT-2上高达7.6倍,图1右面板)并且使用更少的内存(序列长度呈线性),这主要是由于HBM访问大大减少。本研究分析了FlashAttention的IO复杂度,证明它需要