Sidebar Wallpaper
Author Avatar
Libido Knowledge Bank

Cool & Powerful

  • Total written 2 articles
  • Total created 2 tags
  • Total received 3 comments

Table of ContentsCONTENT

Table of Contents

Flash Attention原理

Administrator
2025-04-06 / 1 Comments / 2 Likes / 50 Views / 0 Words

What、Why、Where、How

What: 通过减少IO访问量加速attention计算

Why: Attention计算是memory-bound而非computation-bound

Where: 以往的文章都注重加速计算过程,而本文着力于减少访存消耗

How: 矩阵分块减少中间结果缓存、反向传播时重新计算中间结果

Result: 2-4x speedup, 10-20x memory reduction

arXiv链接:https://arxiv.org/abs/2205.14135

Why we need this?

在一个标准transformer block中,我们执行类似如下的流程

如下图左侧的PyTorch实现版本所示

# l, d: 长度、特征维度
# input: 形状为l*d的输入
# output: 形状为l*d的输出

def attention(input) -> output:
  Q = W_Q * input # l*d
  K = W_K * input # l*d
  V = W_V * input # l*d
  att = Q * K.T # l*l
  att = sigmoid(att)
  att = dropout(att)
  out = att * V # l*d
  out = W_O * out
  return out

对于后GPT-3.5的模型来说,d和支持的length都很大。

在计算这些矩阵的时候,不可能把他们全放在SBM内存中计算。

需要大量的将数据转移到HBM中来腾出空间计算矩阵这一操作。

而正是这一特性导致attention计算是memory-bound而非computation-bound。

How we achieved this?

If softmax doesn't exist...

拿没有softmax的attention计算过程(假设长度length=15)举例子:

由于矩阵较大,中间结果需多次写入显存再读回,导致IO开销极高。如果我们采用分块矩阵的思想:

先拿出Q的前3行,KT的前5列,计算部分attention_score。(这个attn_score的物理意义是前3个词与前5个词的注意力分数)

然后用这个attn_score与V的前5行做计算,计算出partial_output(1)。

需要注意的是:

  1. 前3个词与前5个词的注意力分数只能与前5个词的V矩阵做计算,得出前三个词每个词与前5个词V的线性组合加和

  2. 这意味着得出的结果还剩下前3个词与后10个词的attention_score与矩阵V做计算,而这些结果也需要与目前3*d的结果直接加起来

  3. 所以我们称目前的结果为partial_output(1)

剩下的步骤就是类似地计算出partial_output(2)和partial_output(3)

最后得到output地前三行即为partial_output(1)+partial_output(2)+partial_output(3)

这样矩阵分块的思想,每个分块的计算完全在SRAM中完成,避免中间结果写回显存。

并且可以通过通过异步操作,在计算当前块时预加载下一块数据。减少了大量IO开销,最终达到2-4x speedup, 10-20x memory reduction。

Plug in softmax

现在我们要把attn_score=QKT的外层加上softmax函数,注意到softmax需要attn_score的一整行数据才能计算

而目前我们经过分块,往往无法在内存中储存完整的attn_score(因为如果储存了,那就又和直接计算没区别了),怎么办?

举个例子,如果我们要计算[5, 6, -1, 7, -4, 3]这个数组的softmax:

除去基本的exp计算,我们除以最大值以缓解数值爆炸带来的精度问题

如果要分别计算前半段和后半段:

然后储存p1, p2, l1, l2, max1, max2:

最后乘以max修正项得到全行的softmax S(a)。

其实本质就是一个数学上的小Trick,就是因为引入了减去exp(max)才导致有点小麻烦

如果计算机的世界没有精度问题,那么这个问题根本不需要讨论(储存p1, l1, ... pn, ln然后直接计算softmax)

Pseudocode

辅以论文中的Pseudocode供参考

This is brand new!

这篇文章的思路很新颖,以往关于attention的加速计算都在提升computation的速度,而本文的作者敏锐的观察到attention计算实际上是memory-bound! 进一步利用巧妙的方法将计算放到SBM上进行,完成了本作Flash Attention。

从2025年的视角来看,Flash attention已然是一个广泛使用的方法,时间和大量实践证明了其理论的正确性。

Result

文中提到Flash attention提供了up to 3× faster than the PyTorch implementation。

并且up to 20× more memory efficient than Pytorch的版本。

arXiv链接:https://arxiv.org/abs/2205.14135

2

Comment Area