Navigation

    Gpushare.com

    • Register
    • Login
    • Search
    • Popular
    • Categories
    • Recent
    • Tags

    去掉Attention的Softmax,复杂度降为O(n)

    语音识别与语义处理领域
    1
    1
    37
    Loading More Posts
    • Oldest to Newest
    • Newest to Oldest
    • Most Votes
    Reply
    • Reply as topic
    Log in to reply
    This topic has been deleted. Only users with topic management privileges can see it.
    • 155****7220
      155****7220 last edited by

      众所周知,尽管基于Attention机制的Transformer类模型有着良好的并行性能,但它的空间和时间复杂度都是O(n2)\mathcal{O}(n^2)O(n2)级别的,nnn是序列长度,所以当nnn比较大时Transformer模型的计算量难以承受。近来,也有不少工作致力于降低Transformer模型的计算量,比如模型剪枝、量化、蒸馏等精简技术,又或者修改Attention结构,使得其复杂度能降低到O(nlog⁡⁡n)\mathcal{O}(n\log⁡n)O(nlog⁡n)甚至O(n)\mathcal{O}(n)O(n)

      论文《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》当中提到一种线性化Attention(Linear Attention)的方法,由此引发了我的兴趣,继而阅读了一些相关博客,有一些不错的收获,最后将自己对线性化Attention的理解汇总在此文中

      Attention

      当前最流行的Attention机制当属Scaled-Dot Attention,即
      $$
      \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}\end{equation}
      $$
      这里的Q∈Rn×dk,K∈Rm×dk,V∈Rm×dv\boldsymbol{Q}\in \mathbb{R}^{n\times d_k}, \boldsymbol{K}\in \mathbb{R}^{m\times d_k}, \boldsymbol{V}\in \mathbb{R}^{m\times d_v}Q∈Rn×dk​,K∈Rm×dk​,V∈Rm×dv​,简单起见我就没显示的写出Attention的缩放因子1d\frac{1}{\sqrt{d}}d​1​了。本文我们主要关心Self Attention的场景,所以为了介绍上的方便,统一设Q,K,V∈Rn×d\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in \mathbb{R}^{n\times d}Q,K,V∈Rn×d

      摘掉Softmax

      读者也许想不到,制约Attention性能的关键因素,其实是定义里边的Softmax!事实上,简单地推导一下就可以得到这个结论。QKTQK^TQKT这一步我们得到一个n×nn\times nn×n的矩阵,之后还要做一个Softmax

      对一个1×n1\times n1×n的行向量进行Softmax,时间复杂度是O(n)O(n)O(n),但是对一个n×nn\times nn×n矩阵的每一行做一个Softmax,时间复杂度就是O(n2)O(n^2)O(n2)

      如果没有Softmax,那么Attention的公式就变为三个矩阵连乘QK⊤V\boldsymbol{QK^{\top}V}QK⊤V,而矩阵乘法是满足结合率的,所以我们可以先算K⊤V\boldsymbol{K^{\top}V}K⊤V,得到一个d×dd\times dd×d的矩阵(这一步的时间复杂度是O(d2n)O(d^2n)O(d2n)),然后再用QQQ左乘它(这一步的时间复杂度是O(d2n)O(d^2n)O(d2n)),由于d≪nd \ll nd≪n,所以这样算大致的时间复杂度只是O(n)O(n)O(n)

      对于BERT base来说,d=64d=64d=64而不是768,why?因为768实际上是通过Multi-Head拼接得到的,而每个head的d=64d=64d=64

      也就是说,去掉Softmax的Attention复杂度可以降到最理想的线性级别O(n)\mathcal{O}(n)O(n)!这显然就是我们的终极追求:Linear Attention

      一般的定义

      问题是,直接去掉Softmax还能算是Attention吗?他还能有标准的Attention的效果吗?为了回答这个问题,我们先将Scaled-Dot Attention的定义等价的改写为(本文的向量都是列向量)
      $$
      \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}j}{\sum\limits{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}\end{equation}
      $$

      这里稍微解释下,首先我们知道Q,K∈Rn×d\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d}Q,K∈Rn×d,令M=Q×K⊤\boldsymbol{M} = \boldsymbol{Q}\times \boldsymbol{K^{\top}}M=Q×K⊤,由矩阵乘法法则可知,M\boldsymbol{M}M的第一行是由Q\boldsymbol{Q}Q的第一行乘以K⊤\boldsymbol{K^{\top}}K⊤的所有列得到的

      Attention(Q,K,V)iAttention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_iAttention(Q,K,V)i​表示最终输出结果矩阵的第iii行

      qi⊤\boldsymbol{q}_i^{\top}qi⊤​表示Q∈Rn×d\boldsymbol{Q}\in \mathbb{R}^{n\times d}Q∈Rn×d矩阵的第iii行(行向量)

      kj\boldsymbol{k}_jkj​表示K⊤∈Rd×n\boldsymbol{K^{\top}}\in \mathbb{R}^{d\times n}K⊤∈Rd×n矩阵的第jjj列(列向量)

      vj\boldsymbol{v}_jvj​表示V⊤∈Rd×nV^{\top}\in \mathbb{R}^{d\times n}V⊤∈Rd×n矩阵的的第jjj列(列向量)

      所以,Scaled-Dot Attention其实就是以eqi⊤kje^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}eqi⊤​kj​为权重对vj\boldsymbol{v}_jvj​做加权平均。所以我们可以提出一个Attention的一般化定义
      $$
      \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})i = \frac{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}j}{\sum\limits{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)}\tag{3}\end{equation}
      $$
      也就是把eqi⊤kje^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}eqi⊤​kj​换成qi,ki\boldsymbol{q}_i,\boldsymbol{k}_iqi​,ki​的一般函数sim(qi,kj)\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)sim(qi​,kj​),为了保留Attention相似的分布特性,我们要求sim(qi,kj)≥0\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0sim(qi​,kj​)≥0恒成立。也就是说,我们如果要定义新的Attention,必须要保留式(3)的形式,并且满足sim(qi,kj)≥0\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\geq 0sim(qi​,kj​)≥0

      这种一般形式的Attention在CV中也被称为Non-Local网络,出自论文《Non-local Neural Networks》

      几个例子

      如果直接去掉Softmax,那么就是sim(qi,kj)=qi⊤kj\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \boldsymbol{q}_i^{\top}\boldsymbol{k}_jsim(qi​,kj​)=qi⊤​kj​,问题是内积无法保证非负性,所以这还不是一个合理的选择。下面我们介绍几种可取的方案

      值得一提的是,下面介绍的这几种Linear Attention,前两种来自CV领域,第三种是苏剑林大佬构思的(除了下面的介绍外,还有EMANet等CV领域对Attention的改进工作)

      核函数形式

      一个自然的想法是:如果qi,kj\boldsymbol{q}_i, \boldsymbol{k}_jqi​,kj​的每个元素都是非负的,那么内积自然也是非负的。为了完成这点,我们可以给qi,kj\boldsymbol{q}_i, \boldsymbol{k}_jqi​,kj​各自加个激活函数ϕ,φ\phi,\varphiϕ,φ,即
      $$
      \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\tag{4}\end{equation}
      $$
      其中ϕ(⋅),φ(⋅)\phi(\cdot), \varphi(\cdot)ϕ(⋅),φ(⋅)是值域非负的激活函数。本文开头提到的论文《Transformers are RNNs》选择的是ϕ(x)=φ(x)=elu(x)+1\phi(x)=\varphi(x)=\text{elu}(x)+1ϕ(x)=φ(x)=elu(x)+1,其中

      $$
      \text{elu}(x)=\begin{cases}x& \text{if} \ x>0\ \alpha (e^x-1) & \text{if}\ x<0\end{cases}
      $$

      常见的α\alphaα取值为[0.1,0.3][0.1, 0.3][0.1,0.3]

      非要讲故事的话,式(4)可以联想到"核方法",尤其是ϕ=φ\phi=\varphiϕ=φ时,ϕ\phiϕ就相当于一个核函数,而⟨ϕ(qi),ϕ(kj)⟩\langle \phi(\boldsymbol{q}_i), \phi(\boldsymbol{k}_j)\rangle⟨ϕ(qi​),ϕ(kj​)⟩就是通过核函数所定义的内积。这方面的思考可以参考论文《Transformer dissection: An unified understanding for transformer’s attention via the lens of kernel》,此处不做过多延伸

      妙用Softmax

      另一篇更早的文章《Efficient Attention: Attention with Linear Complexities》则给出了一个更有意思的选择。它留意到在QK⊤\boldsymbol{QK^{\top}}QK⊤中,Q,K∈Rn×d\boldsymbol{Q},\boldsymbol{K}\in \mathbb{R}^{n\times d}Q,K∈Rn×d,如果“Q\boldsymbol{Q}Q在ddd那一维是归一化的,并且K\boldsymbol{K}K在nnn那一维是归一化的”,那么QK⊤\boldsymbol{QK^{\top}}QK⊤就是自动满足归一化了,所以它给出的选择是
      $$
      \begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = softmax_2\left(\boldsymbol{Q}\right)softmax_1(\boldsymbol{K})^{\top}\boldsymbol{V}\tag{5}\end{equation}
      $$
      其中softmax1softmax_1softmax1​、softmax2softmax_2softmax2​分别表示在第一个(n)(n)(n)、第二个维度(d)(d)(d)进行Softmax运算。也就是说,这时候我们是各自给Q,K\boldsymbol{Q},\boldsymbol{K}Q,K加Softmax,而不是算完QK⊤\boldsymbol{QK^{\top}}QK⊤之后再加Softmax

      其实可以证明这个形式也是式(4)​的一个特例,此时对应于ϕ(qi)=softmax(qi),φ(kj)=ekj\phi(\boldsymbol{q}_i)=softmax(\boldsymbol{q}_i),\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j}ϕ(qi​)=softmax(qi​),φ(kj​)=ekj​,读者可以自行推导一下

      苏神的构思

      在这里,苏神给出了一种构思。这个构思的出发点不再是式(4),而是源于我们对原始定义(2)​的泰勒展开。由泰勒展开我们有
      $$
      \begin{equation}e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j\tag{6}\end{equation}
      $$
      如果qi⊤kj≥−1\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1qi⊤​kj​≥−1,那么就可以保证右端的非负性,从而可以让sim(qi,kj)=1+qi⊤kj\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)=1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_jsim(qi​,kj​)=1+qi⊤​kj​。到这里读者可能已经想到了,想要保证qi⊤kj≥−1\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1qi⊤​kj​≥−1,只需要分别对qi,kj\boldsymbol{q}_i,\boldsymbol{k}_jqi​,kj​做l2l_2l2​归一化。所以,苏神最终提出的方案就是:
      $$
      \begin{equation}\text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j) = 1 + \left( \frac{\boldsymbol{q}_i}{\Vert \boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert \boldsymbol{k}_j\Vert}\right)\tag{7}\end{equation}
      $$

      若x=[x1,x2,…,xn]\boldsymbol{x}=[x_1,x_2,…,x_n]x=[x1​,x2​,…,xn​],则∥x∥=x12+x22+⋅⋅⋅+xn2\Vert x\Vert=\sqrt{x_1^2+x_2^2+···+x_n^2}∥x∥=x12​+x22​+⋅⋅⋅+xn2​​

      这不同于式(4),但理论上它更加接近原始的Scaled-Dot Attention

      实现

      这里主要是针对苏神所提出的方法进行实现,但是由于笔者本人水平有限,因此最终实现的代码当中其实存在一些问题,主要是:

      1. 从测试结果来看,改进后的计算速度并没有提升
      2. 无法做到求和为1

      代码实现主要是针对BERT的PyTorch实现这篇文章的代码,更具体的说,其实仅修改了ScaledDotProductAttention这个函数,因此下面只放出这部分代码

      class ScaledDotProductAttention(nn.Module):
          def __init__(self):
              super(ScaledDotProductAttention, self).__init__()
      
          def forward(self, Q, K, V, attn_mask):
              Q = F.normalize(Q, dim=3)
              K = F.normalize(K, dim=3)
              M = (torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2]) + torch.matmul(Q, K.transpose(-1, -2))) # scores : [batch_size, n_heads, seq_len, seq_len]
              M_sum = torch.sum(M, dim=3)
              M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])
              attn = M.masked_fill(attn_mask, 0) # Fills elements of self tensor with value where mask is one.
              context = torch.matmul(attn, V)
              return context
      

      如果您有更好的实现方法,还望不吝赐教

      Reference

      • 线性Attention的探索:Attention必须有个Softmax吗?
      1 Reply Last reply Reply Quote 1
      • First post
        Last post