Linear Attention

在标准的self-attention架构中, V ′ = A t t n ⋅ V V'=Attn \cdot V V=AttnV,那么对其中某个 V i ′ V'_i Vi
v ˉ i = ∑ j s o f t m a x i j ( Q K  ⁣ ⊤ d ) ⋅ v j \bar{v}_i=\sum_j softmax_{ij}\left( \frac{QK^{\!\top}}{\sqrt d} \right) \cdot v_j vˉi=jsoftmaxij(d QK)vj
由于 s o f t m a x softmax softmax针对行,于是
v ˉ  ⁣ ⊤ = ∑ j s o f t m a x i j ( Q K  ⁣ ⊤ d ) ⋅ v j  ⁣ ⊤ = ∑ j e x p ( q i  ⁣ ⊤ k j d ) v j  ⁣ ⊤ ∑ j e x p ( q i  ⁣ ⊤ k j d ) \bar v^{\!\top}=\sum_j softmax_{ij} \left( \frac{QK^{\!\top}}{\sqrt d} \right) \cdot v_j^{\!\top} =\frac{\sum_j exp(\frac{q_i^{\!\top} k_j}{\sqrt d})v_j^{\!\top}}{\sum_j exp(\frac{q_i^{\!\top} k_j}{\sqrt d})} vˉ=jsoftmaxij(d QK)vj=jexp(d qikj)jexp(d qikj)vj
为了进行线性化,简化 e x p ( q  ⁣ ⊤ k d ) exp(\frac{q^{\!\top} k}{\sqrt d}) exp(d qk)这个二元函数中的2个变量为独立的,即 e x p ( q  ⁣ ⊤ k d ) = ϕ ( q  ⁣ ⊤ ) ϕ ( k ) exp(\frac{q^{\!\top} k}{\sqrt d})=\phi(q^{\!\top})\phi(k) exp(d qk)=ϕ(q)ϕ(k),其中 ϕ \phi ϕ Q , V Q, V Q,V都是row-wise的。于是
v ˉ i  ⁣ ⊤ = ∑ j ϕ ( q i  ⁣ ⊤ ) ϕ ( k j ) v j  ⁣ ⊤ ∑ j ϕ ( q i  ⁣ ⊤ ) ϕ ( k j ) = ϕ ( q i  ⁣ ⊤ ) ∑ j ϕ ( k j ) v j  ⁣ ⊤ ϕ ( q i  ⁣ ⊤ ) ∑ j ϕ ( k j ) \bar v_i^{\!\top}=\frac{\sum_j\phi(q_i^{\!\top})\phi(k_j)v_j^{\!\top}}{\sum_j\phi(q_i^{\!\top})\phi(k_j)} = \frac{\phi(q_i^{\!\top})\sum_j\phi(k_j)v_j^{\!\top}}{\phi(q_i^{\!\top})\sum_j\phi(k_j)} vˉi=jϕ(qi)ϕ(kj)jϕ(qi)ϕ(kj)vj=ϕ(qi)jϕ(kj)ϕ(qi)jϕ(kj)vj
当考虑mask时,
v ˉ i T = ϕ ( q i  ⁣ ⊤ ) ∑ j = 1 i ϕ ( k j ) v j  ⁣ ⊤ ϕ ( q i  ⁣ ⊤ ) ∑ j = 1 i ϕ ( k j ) = ϕ ( q i  ⁣ ⊤ ) S i ϕ ( q i  ⁣ ⊤ ) Z i \bar v_i^T=\frac{\phi(q_i^{\!\top})\sum_{j=1}^i\phi(k_j)v_j^{\!\top}}{\phi(q_i^{\!\top})\sum_{j=1}^i\phi(k_j)} =\frac{\phi(q_i^{\!\top})S_i}{\phi(q_i^{\!\top})Z_i} vˉiT=ϕ(qi)j=1iϕ(kj)ϕ(qi)j=1iϕ(kj)vj=ϕ(qi)Ziϕ(qi)Si
其中, Z i = ∑ j = 1 i ϕ ( k j ) , S i = ∑ j = 1 i ϕ ( k j ) v j  ⁣ ⊤ Z_i=\sum_{j=1}^i\phi(k_j), S_i=\sum_{j=1}^i\phi(k_j)v_j^{\!\top} Zi=j=1iϕ(kj),Si=j=1iϕ(kj)vj,由此实现了对标准self-attention的线性化近似,以及 Z i , S i Z_i, S_i Zi,Si可以由 Z i − 1 , S i − 1 Z_{i-1}, S_{i-1} Zi1,Si1计算得到,可以通过cache节省计算时间。

DeltaNet

DeltaNet的定义

在上面的linear attention中, S t = S t − 1 + v t k t  ⁣ ⊤ S_t=S_{t-1}+v_tk_t^{\!\top} St=St1+vtkt累加得到——简洁起见,用 k k k代替 ϕ ( k ) \phi(k) ϕ(k)——可以看作memory,每次的 v t , k t v_t, k_t vt,kt更新这个memory。但由于 v t k t  ⁣ ⊤ v_tk_t^{\!\top} vtkt的维度最高位d, S S S能表示的信息最多为 d d d;而简单的累加不会删除早期的记忆,因此无法有效应对序列长度 N > d N\gt d N>d的情形。一个更合理的模式应该是, S S S会在每轮更新中逸出过去的不重要的k-v关联,来为后续新的变量腾出空间。

根据delta update rule,重新定义损失函数和更新规则为 L t ( S ) = 1 2 ∥ S k t − v t ∥ 2          S t = S t − 1 − β t ( S t − 1 k t − v t ) k t  ⁣ ⊤ \mathcal L_t(S) = \frac12\| S k_t - v_t \|^2\ \ \ \ \ \ \ \ S_t=S_{t-1}-\beta_t(S_{t-1}k_t-v_t)k_t^{\!\top} Lt(S)=21Sktvt2        St=St1βt(St1ktvt)kt 在这里, β t \beta_t βt代表学习率, S t − 1 k t S_{t-1}k_t St1kt代表根据当前memory S t − 1 S_{t-1} St1根据新一轮 k t k_t kt对新一轮目标 v t v_t vt的预测。参数更新的目标是消除“预测” S t − 1 k t S_{t-1}k_t St1kt与"目标“ v t v_t vt之间的difference,这也是delta的含义。于是可以验证更新规则 S t = S t − 1 − β t   ∇  ⁣ S t − 1 L t ( S t − 1 ) = S t − 1 − β t   ( S t − 1 k t − v t ) k t  ⁣ ⊤ \begin{equation} S_t = S_{t-1} - \beta_t \, \nabla_{\!S_{t-1}} \mathcal L_t(S_{t-1}) = S_{t-1} - \beta_t \, (S_{t-1} k_t - v_t) k_t^{\!\top} \end{equation} St=St1βtSt1Lt(St1)=St1βt(St1ktvt)kt

DeltaNet的另一种解释

从key-value retrieval的角度,当前key会retrieve到从前的value: v t o l d = S t − 1 k t v_t^{old}=S_{t-1}k_t vtold=St1kt;新的value则由从前的value和当前的value值插值而来: v t n e w = β t v t + ( 1 − β t ) v t o l d \begin{equation} v_t^{new}=\beta_t v_t+(1-\beta_t)v_t^{old} \end{equation} vtnew=βtvt+(1βt)vtold S t = S t − 1 − v t o l d k t  ⁣ ⊤ + v t n e w k t  ⁣ ⊤ \begin{equation} S_t=S_{t-1}-v_t^{old}k_t^{\!\top}+v_t^{new}k_t^{\!\top} \end{equation} St=St1vtoldkt+vtnewkt公式(3)中的减和加分别代表移除旧的无用信息和增添新的信息。将 v t n e w − v t o l d v_t^{new}-v_t^{old} vtnewvtold记作 u t u_t ut,于是根据式(2)有 u t = β t ( v t − v t o l d ) = β t ( v t − S t − 1 k t ) \begin{equation} u_t=\beta_t(v_t-v_t^{old})=\beta_t(v_t-S_{t-1}k_t) \end{equation} ut=βt(vtvtold)=βt(vtSt1kt)

DeltaNet的线性性

接下来证明,DeltaNet的第2种解释与原本的定义是一致的,并且 S t S_t St可以表示为 ∑ i = 1 t u i k i T \sum_{i=1}^t u_ik_i^T i=1tuikiT,进而证明DeltaNet的线性性。

t = 1 t=1 t=1时, S 1 = β 1 v 1 k 1  ⁣ ⊤ S_1=\beta_1v_1k_1^{\!\top} S1=β1v1k1;假设对 t − 1 t-1 t1 S t − 1 = ∑ i = 1 t − 1 u i k i  ⁣ ⊤ S_{t-1}=\sum_{i=1}^{t-1} u_ik_i^{\!\top} St1=i=1t1uiki成立,则对 S t S_t St,参照式(3), S t = S t − 1 − v t o l d k t  ⁣ ⊤ + v t n e w k t  ⁣ ⊤ = S t − 1 + β t ( v t − S t − 1 k t ) k t  ⁣ ⊤ = S t − 1 ( I − β t k t k t  ⁣ ⊤ ) + β t v t k t  ⁣ ⊤ \begin{equation} \begin{aligned} S_t &= S_{t-1} - v_t^{old}k_t^{\!\top} + v_t^{new}k_t^{\!\top} \\[1mm] &= S_{t-1} + \beta_t(v_t-S_{t-1}k_t)k_t^{\!\top} \\[1mm] &= S_{t-1}(I - \beta_t k_t k_t^{\!\top}) + \beta_t v_t k_t^{\!\top} \end{aligned}\end{equation} St=St1vtoldkt+vtnewkt=St1+βt(vtSt1kt)kt=St1(Iβtktkt)+βtvtkt这与最初定义时的式(1)是一致的;进而 S t = S t − 1 + β t ( v t − S t − 1 k t ) k t  ⁣ ⊤ = S t − 1 + u t k t  ⁣ ⊤ \begin{equation} S_t=S_{t-1}+\beta_t\left( v_t-S_{t-1}k_t \right)k_t^{\!\top}=S_{t-1}+u_tk_t^{\!\top} \end{equation} St=St1+βt(vtSt1kt)kt=St1+utkt

DeltaNet的Chunkwise形式

为推导chunkwise parallel形式,首先展开式(5)的循环: S t = β t v t k t  ⁣ ⊤ + S t − 1 ( I − β t k t k t  ⁣ ⊤ ) = β t v t k t  ⁣ ⊤ + ( β t − 1 v t − 1 k t − 1  ⁣ ⊤ + S t − 2 ( I − β t − 1 k t − 1 k t − 1  ⁣ ⊤ ) ) ( I − β t k t k t  ⁣ ⊤ ) = β t v t k t  ⁣ ⊤ + β t − 1 v t − 1 k t − 1  ⁣ ⊤ ( I − β t k t k t  ⁣ ⊤ ) + S t − 2 ( I − β t − 1 k t − 1 k t − 1  ⁣ ⊤ ) ( I − β t k t k t  ⁣ ⊤ ) = ∑ i = 1 t β i v i k i  ⁣ ⊤ ( ∏ j = i + 1 t ( I − β j k j k j  ⁣ ⊤ ) ) \begin{equation}\begin{aligned} S_t &= \beta_t v_t k_t^{\!\top} + S_{t-1}(I - \beta_t k_t k_t^{\!\top}) \\[1mm] &= \beta_t v_t k_t^{\!\top} + \left(\beta_{t-1}v_{t-1}k_{t-1}^{\!\top}+S_{t-2}(I-\beta_{t-1}k_{t-1}k_{t-1}^{\!\top})\right)(I - \beta_t k_t k_t^{\!\top}) \\[1mm] &=\beta_t v_t k_t^{\!\top} + \beta_{t-1}v_{t-1}k_{t-1}^{\!\top}(I - \beta_t k_t k_t^{\!\top})+S_{t-2}(I-\beta_{t-1}k_{t-1}k_{t-1}^{\!\top})(I - \beta_t k_t k_t^{\!\top}) \\[1mm] &=\sum_{i=1}^t\beta_iv_ik_i^{\!\top} \left( \prod_{j=i+1}^{t}\bigl(I - \beta_j k_j k_j^{\!\top}\bigr) \right) \end{aligned}\end{equation} St=βtvtkt+St1(Iβtktkt)=βtvtkt+(βt1vt1kt1+St2(Iβt1kt1kt1))(Iβtktkt)=βtvtkt+βt1vt1kt1(Iβtktkt)+St2(Iβt1kt1kt1)(Iβtktkt)=i=1tβiviki(j=i+1t(Iβjkjkj))在式(7)中,指定 P i j = ∏ t = i j ( I − β t k t k t ⊤ ) ∈ R d × d P_i^j=\prod_{t=i}^j(I-\beta_tk_tk_t^{\top})\in\mathbb{R}^{d\times d} Pij=t=ij(Iβtktkt)Rd×d,以及 H i j = ∑ t = i j β t v t k t ⊤ P t + 1 j H_i^j=\sum_{t=i}^j\beta_tv_tk_t^{\top}P_{t+1}^j Hij=t=ijβtvtktPt+1j,并规定当 i > j i>j i>j P i j = I P_i^j=I Pij=I。从直觉上看, P i j P_i^j Pij是由 S i S_i Si得到 S j S_j Sj的衰减因子。

接着,将整个长为 N N N的序列分为长度为 C C C的块。对于某个分块 [ t + 1 ] [t+1] [t+1],则有 S [ t ] r = ∑ i = 1 t C + r β i v i k i ⊤ ( ∏ j = i + 1 t C + r ( I − β j k j k j  ⁣ ⊤ ) ) = ∑ i = 1 t C β i v i k i ⊤ ∏ j = i + 1 t C ( I − β j k j k j  ⁣ ⊤ ) ∏ j = t C + 1 t C + r ( I − β j k j k j  ⁣ ⊤ )       + ∑ i = t C + 1 t C + r β i v i k i ⊤ ∏ j = i + 1 t C + r ( I − β j k j k j  ⁣ ⊤ ) = S t C P t C + 1 t C + r + S t C + 1 t C + r = S [ t ] 0 P 0 [ t ] r + H [ t ] r \begin{equation}\begin{aligned} S_{[t]}^r &=\sum_{i=1}^{tC+r}\beta_iv_ik_i^\top \left( \prod_{j=i+1}^{tC+r}\bigl(I - \beta_j k_j k_j^{\!\top}\bigr) \right) \\[1mm] &=\sum_{i=1}^{tC}\beta_iv_ik_i^\top \prod_{j=i+1}^{tC}\bigl(I - \beta_j k_j k_j^{\!\top}\bigr) \prod_{j=tC+1}^{tC+r}\bigl(I - \beta_j k_j k_j^{\!\top}\bigr) \\[1mm] & \ \ \ \ \ +\sum_{i=tC+1}^{tC+r}\beta_iv_ik_i^\top \prod_{j=i+1}^{tC+r}\bigl(I - \beta_j k_j k_j^{\!\top}\bigr) \\[[1mm] &= S_{tC}P_{tC+1}^{tC+r}+S_{tC+1}^{tC+r} = S_{[t]}^0P_{0[t]}^r+H_{[t]}^r \end{aligned}\end{equation} S[t]r=i=1tC+rβiviki(j=i+1tC+r(Iβjkjkj))=i=1tCβivikij=i+1tC(Iβjkjkj)j=tC+1tC+r(Iβjkjkj)     +i=tC+1tC+rβivikij=i+1tC+r(Iβjkjkj)=StCPtC+1tC+r+StC+1tC+r=S[t]0P0[t]r+H[t]r式中, S [ t ] i = S t C + i , P 0 [ t ] r = P t C + 1 t C + r , H [ t ] r = H t C + 1 t C + r S_{[t]}^i=S_{tC+i}, P_{0[t]}^r=P_{tC+1}^{tC+r}, H_{[t]}^r=H_{tC+1}^{tC+r} S[t]i=StC+i,P0[t]r=PtC+1tC+r,H[t]r=HtC+1tC+r。为了在实际代码实现线性attention,需要证明可以用累加的方法得到 P [ t ] r , H [ t ] r P_{[t]}^r, H_{[t]}^r P[t]r,H[t]r

Chunkwise线性性的证明

先列出文章提出的公式,如果关注主要思路可以先跳过后面的证明
P [ t ] r = I − ∑ i = 1 r w [ t ] i k [ t ] i ⊤ \begin{equation}P_{[t]}^r=I-\sum_{i=1}^r w_{[t]}^i k_{[t]}^{i\top}\end{equation} P[t]r=Ii=1rw[t]ik[t]i H [ t ] r = ∑ i = 1 r u [ t ] i k [ t ] i ⊤ \begin{equation}H_{[t]}^r=\sum_{i=1}^r u_{[t]}^i k_{[t]}^{i\top}\end{equation} H[t]r=i=1ru[t]ik[t]i w [ t ] r = ∏ i = 1 r − 1 ( I − β [ t ] i k [ t ] i k [ t ] i ⊤ ) β [ t ] r \begin{equation}w_{[t]}^r=\prod_{i=1}^{r-1}\Bigl( I-\beta_{[t]}^i k_{[t]}^i k_{[t]}^{i\top} \Bigr)\beta_{[t]}^r\end{equation} w[t]r=i=1r1(Iβ[t]ik[t]ik[t]i)β[t]r u [ t ] r = β r v r − β r ∑ i = 1 r − 1 β i v i k i ⊤ ( ∏ l = i + 1 r − 1 ( I − β l k l k l ⊤ ) ) k r \begin{equation}u_{[t]}^r=\beta^r v^r - \beta^r \sum_{i=1}^{r-1}\beta^i v^i k^{i\top} \Bigl(\prod_{l=i+1}^{r-1}(I-\beta^l k^l k^{l\top})\Bigr) k^r\end{equation} u[t]r=βrvrβri=1r1βiviki(l=i+1r1(Iβlklkl))kr

  1. 证明(9),即 P [ t ] r P_{[t]}^r P[t]r可以通过 w [ t ] i w_{[t]}^i w[t]i表示为累加形式(以下证明中省略下标 [ t ] [t] [t]
    P r = ∏ ( I − β t k t k t ⊤ )      展开后考虑每一项 k i k i ⊤ 前面乘的因子 = I − ∑ i = 1 r ∏ j = 1 i − 1 ( I − β j k j k j ⊤ ) β i k i k i ⊤ = I − ∑ i = 1 r w i k i ⊤ \begin{aligned} P^r &= \prod\bigl( I-\beta^t k^t k^{t\top} \bigr)\ \ \ \ \text{展开后考虑每一项$k^i k^{i\top}$前面乘的因子} \\[1mm] &= I - \sum_{i=1}^r\prod_{j=1}^{i-1}\bigl(I-\beta^j k^j k^{j\top} \bigr)\beta^i k^i k^{i\top} \\[1mm] &= I - \sum_{i=1}^r w^i k^{i\top} \end{aligned} Pr=(Iβtktkt)    展开后考虑每一项kiki前面乘的因子=Ii=1rj=1i1(Iβjkjkj)βikiki=Ii=1rwiki

  2. 证明(10), w [ t ] r w_{[t]}^r w[t]r可以递归表示
    w r = ∏ i = 1 r − 1 ( I − β i k i k i ⊤ ) β r      类似的考虑每个 k i 前面乘的因子 = β r ( I − ∑ i = 1 r − 1 ∏ j = 1 i − 1 ( I − β j k j k j ⊤ ) β i k i k i ⊤ ) k r = β r ( I − ∑ i = 1 r − 1 w i k i ⊤ ) k r \begin{aligned} w^r &= \prod_{i=1}^{r-1}\Bigl( I-\beta^i k^i k^{i\top} \Bigr)\beta^r \ \ \ \ \text{类似的考虑每个$k^i$前面乘的因子} \\[1mm] &= \beta^r\Bigl( I - \sum_{i=1}^{r-1}\prod{j=1}^{i-1}(I-\beta^j k^j k^{j\top})\beta^i k^i k^{i\top} \Bigr) k^r \\[1mm] &= \beta^r (I-\sum_{i=1}^{r-1}w^i k^{i\top})k^r \end{aligned} wr=i=1r1(Iβikiki)βr    类似的考虑每个ki前面乘的因子=βr(Ii=1r1j=1i1(Iβjkjkj)βikiki)kr=βr(Ii=1r1wiki)kr

  3. 证明(11), H [ t ] r H_{[t]}^r H[t]r可以通过 u [ t ] i u_{[t]}^i u[t]i表示为累加形式
    H r = ∑ i = 1 r β i v i k i ⊤ P i + 1 r = ∑ i = 1 r β i v i k i ⊤ ( ∏ j = i + 1 r ( I − β j k j k j ⊤ ) )    考虑每个 k i k i ⊤ 前的乘子 = ∑ i = 1 r ( β i v i k i ⊤ − ∑ j = 1 i − 1 β j v j k j ⊤ ( ∏ l = j + 1 i − 1 ( I − β l k l k l ⊤ ) ) β i k i k i ⊤ ) = ∑ i = 1 r u i k i ⊤ \begin{aligned} H^r &= \sum_{i=1}^r \beta^i v^i k^{i\top} P_{i+1}^r \\[1mm] &= \sum_{i=1}^r \beta^i v^i k^{i\top} \Bigl( \prod_{j=i+1}^r (I-\beta^j k^j k^{j\top}) \Bigr)\ \ \text{考虑每个$k^ik^{i\top}$前的乘子} \\[1mm] &= \sum_{i=1}^r\biggl( \beta^i v^i k^{i\top} - \sum_{j=1}^{i-1}\beta^j v^j k^{j\top} \Bigl(\prod_{l=j+1}^{i-1}(I-\beta^l k^l k^{l\top})\Bigr) \beta^i k^i k^{i\top}\biggr) \\[1mm] &= \sum_{i=1}^r u^i k^{i\top} \end{aligned} Hr=i=1rβivikiPi+1r=i=1rβiviki(j=i+1r(Iβjkjkj))  考虑每个kiki前的乘子=i=1r(βivikij=1i1βjvjkj(l=j+1i1(Iβlklkl))βikiki)=i=1ruiki

  4. 证明(12), u [ t ] r u_{[t]}^r u[t]r可以递归表示
    u r = β r ( v r − ∑ j = 1 r − 1 β j v j k j ⊤ ( ∏ l = j + 1 r − 1 ( I − β l k l k l ⊤ ) ) k r )    考虑每个 k j k j ⊤ 前的乘子 = β r ( v r − ∑ j = 1 r − 1 ( β j v j k j ⊤ − ∑ i = 1 j − 1 β i v i k i ⊤ ( ∏ l = i + 1 j − 1 ( I − β l k l k l ⊤ ) β j k j k j ⊤ ) ) k r ) = β r ( v r − ∑ j = 1 r − 1 β j ( v j − ∑ i = 1 j − 1 β i v i k i ⊤ ( ∏ l = i + 1 j − 1 ( I − β l k l k l ⊤ ) k j ) ) k j ⊤ k r ) = β r ( v r − ∑ j = 1 r − 1 u j k j ⊤ k r ) \begin{aligned} u^r &= \beta^r \Bigl( v^r - \sum_{j=1}^{r-1}\beta^j v^j k^{j\top} \bigl(\prod_{l=j+1}^{r-1}(I-\beta^l k^l k^{l\top})\bigr) k^r\Bigr) \ \ \text{考虑每个$k^j k^{j\top}$前的乘子} \\[1mm] &= \beta^r\biggl( v^r - \sum_{j=1}^{r-1}\Bigl(\beta^jv^jk^{j\top}-\sum_{i=1}^{j-1}\beta^iv^ik^{i\top}\bigl( \prod_{l=i+1}^{j-1}(I-\beta^lk^lk^{l\top}) \beta^jk^jk^{j\top} \bigr) \Bigr)k^r\biggr) \\[1mm] &= \beta^r\biggl( v^r - \sum_{j=1}^{r-1}\beta^j\Bigl(v^j-\sum_{i=1}^{j-1}\beta^iv^ik^{i\top}\bigl( \prod_{l=i+1}^{j-1}(I-\beta^lk^lk^{l\top}) k^j \bigr) \Bigr)k^{j\top}k^r\biggr) \\[1mm] &= \beta^r (v^r - \sum_{j=1}^{r-1}u^jk^{j\top}k^r) \end{aligned} ur=βr(vrj=1r1βjvjkj(l=j+1r1(Iβlklkl))kr)  考虑每个kjkj前的乘子=βr(vrj=1r1(βjvjkji=1j1βiviki(l=i+1j1(Iβlklkl)βjkjkj))kr)=βr(vrj=1r1βj(vji=1j1βiviki(l=i+1j1(Iβlklkl)kj))kjkr)=βr(vrj=1r1ujkjkr)

Chunkwise DeltaNet的矩阵形式

将上节的parallel形式带入(8),有
S [ t ] r = S [ t ] 0 − ( S [ t ] 0 ∑ i = 1 r w [ t ] i k [ t ] i ⊤ ) + ∑ i = 1 r u [ t ] i k [ t ] i ⊤ = S [ t ] 0 + ∑ i = 1 r ( u [ t ] i − S [ t ] 0 w [ t ] i ) k [ t ] i ⊤ \begin{equation} S_{[t]}^r=S_{[t]}^0-\Bigl( S_{[t]}^0\sum_{i=1}^r w_{[t]}^ik_{[t]}^{i\top} \Bigr)+\sum_{i=1}^r u_{[t]}^ik_{[t]}^{i\top}=S_{[t]}^0+\sum_{i=1}^r\bigl( u_{[t]}^i-S_{[t]}^0w_{[t]}^i \bigr)k_{[t]}^{i\top} \end{equation} S[t]r=S[t]0(S[t]0i=1rw[t]ik[t]i)+i=1ru[t]ik[t]i=S[t]0+i=1r(u[t]iS[t]0w[t]i)k[t]i o [ t ] r = S [ t ] r q [ t ] r = S [ t ] 0 q [ t ] r + ∑ i = 1 r ( u [ t ] i − S [ t ] 0 w [ t ] i ) ( k [ t ] i ⊤ q [ t ] r ) \begin{equation} o_{[t]}^r=S_{[t]}^rq_{[t]}^r=S_{[t]}^0q_{[t]}^r+\sum_{i=1}^r\bigl(u_{[t]}^i-S_{[t]}^0w_{[t]}^i \bigr)(k_{[t]}^{i\top}q_{[t]}^r) \end{equation} o[t]r=S[t]rq[t]r=S[t]0q[t]r+i=1r(u[t]iS[t]0w[t]i)(k[t]iq[t]r)

因此整个chunk表示为
S [ t + 1 ] = S [ t ] + ( U [ t ] − W [ t ] S [ t ] ⊤ ) ⊤ K [ t ] \begin{equation} S_{[t+1]}=S_{[t]}+\bigl(U_{[t]}-W_{[t]}S_{[t]}^\top\bigr)^\top K_{[t]} \end{equation} S[t+1]=S[t]+(U[t]W[t]S[t])K[t] O [ t ] = Q [ t ] S [ t ] ⊤ + ( Q [ t ] K [ t ] ⊤ ⊙ M ) ( U [ t ] − W [ t ] S [ t ] ⊤ ) \begin{equation} O_{[t]}=Q_{[t]}S_{[t]}^\top+(Q_{[t]}K_{[t]}^\top \odot M)\bigl(U_{[t]}-W_{[t]}S_{[t]}^\top\bigr) \end{equation} O[t]=Q[t]S[t]+(Q[t]K[t]M)(U[t]W[t]S[t])其中, M M M是对角不为0的下三角矩阵。

DeltaNet的最终形式

根据式(11)(12), w [ t ] r w_{[t]}^r w[t]r u [ t ] r u_{[t]}^r u[t]r还不能写成张量乘积的形式。为此,进一步使用UT变换,将其写为 W [ t ] = T [ t ] K [ t ] W_{[t]}=T_{[t]}K_{[t]} W[t]=T[t]K[t]的形式。
对于 W [ t ] W_{[t]} W[t],其中的第 r r r行表示为 W [ t ] [ r , : ] = β [ t ] r K [ t ] [ r , : ] − β [ t ] r ∑ i = 1 r − 1 W [ t ] [ i , : ] ( K [ t ] [ i , : ] K [ t ] [ r , : ] ⊤ ) \begin{equation} \mathbf{W}_{[t]}[r, :] = \beta_{[t]}^r \mathbf{K}_{[t]}[r, :] - \beta_{[t]}^r \sum_{i=1}^{r-1} \mathbf{W}_{[t]}[i, :] (\mathbf{K}_{[t]}[i, :] \mathbf{K}_{[t]}[r, :]^\top) \end{equation} W[t][r,:]=β[t]rK[t][r,:]β[t]ri=1r1W[t][i,:](K[t][i,:]K[t][r,:]) B [ t ] = d i a g ( β [ t ] ) , L [ t ] = t r i l ( B [ t ] K [ t ] K [ t ] ⊤ , − 1 ) B_{[t]}=diag(\beta_{[t]}), L_{[t]}=tril(B_{[t]}K_{[t]}K_{[t]}^\top, -1) B[t]=diag(β[t]),L[t]=tril(B[t]K[t]K[t],1) t r i l ( ⋅ , − 1 ) tril(\cdot,-1) tril(,1)表示取矩阵的严格下三角,那么式(17)变为 W [ t ] + L [ t ] W [ t ] = B [ t ] K [ t ] \begin{equation} W_{[t]}+L_{[t]}W_{[t]}=B_{[t]}K_{[t]} \end{equation} W[t]+L[t]W[t]=B[t]K[t] T [ t ] = ( I + L [ t ] ) − 1 B [ t ] T_{[t]}=(I+L_{[t]})^{-1}B_{[t]} T[t]=(I+L[t])1B[t],可以得到 W [ t ] = ( I + L [ t ] ) − 1 B [ t ] K [ t ] = T [ t ] K [ t ] \begin{equation} W_{[t]}=(I+L_{[t]})^{-1}B_{[t]}K_{[t]}=T_{[t]}K_{[t]} \end{equation} W[t]=(I+L[t])1B[t]K[t]=T[t]K[t]类似的,也有 U [ t ] = T [ t ] V [ t ] \begin{equation} U_{[t]}=T_{[t]}V_{[t]} \end{equation} U[t]=T[t]V[t]

于是,chunkwise DeltaNet的整体流程是

  1. 计算 Q , K , V Q, K, V Q,K,V
  2. 计算 T T T
  3. 根据(19)(20)计算 W , U W, U W,U
  4. 对每个chunk,根据(15)(16)计算memory和最终output

Kimi Linear

Kimi Linear在DeltaNet的基础上进一步增加了对角gate D i a g ( α t ) Diag(\alpha_t) Diag(αt),以实现更精细的memory decay和位置信息的控制: S t = β t v t k t ⊤ + S t − 1 D i a g ( α t ) ( I − β t k t k t ⊤ ) \begin{equation} S_t = \beta_t v_t k_t^{\top} + S_{t-1} Diag(\alpha_t) (I - \beta_t k_t k_t^{\top}) \end{equation} St=βtvtkt+St1Diag(αt)(Iβtktkt)这里沿用了DeltaNet文章中的相乘顺序,与Kimi Linear Attention有差别。后面,同样进行了 S [ t ] r = S [ t ] 0 P 0 [ t ] r + H [ t ] r S_{[t]}^r= S_{[t]}^0P_{0[t]}^r+H_{[t]}^r S[t]r=S[t]0P0[t]r+H[t]r的变换,以WY-transform和UV-transform化为chunkwise-parallel格式。

Kimi Linear Attention的部分右面有机会另写一篇来整理。上述内容源于个人读Linear Attention时的推导,如果有谬误欢迎各路大佬指正。

Ref

[1] Katharopoulos, A., Vyas, A., Pappas, N., and Fleuret, F., “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention”, arXiv e-prints, Art. no. arXiv:2006.16236, 2020. doi:10.48550/arXiv.2006.16236.

[2] Schlag, I., Irie, K., and Schmidhuber, J., “Linear Transformers Are Secretly Fast Weight Programmers”, arXiv e-prints, Art. no. arXiv:2102.11174, 2021. doi:10.48550/arXiv.2102.11174.

[3] Yang, S., Wang, B., Zhang, Y., Shen, Y., and Kim, Y., “Parallelizing Linear Transformers with the Delta Rule over Sequence Length”, arXiv e-prints, Art. no. arXiv:2406.06484, 2024. doi:10.48550/arXiv.2406.06484.

[4]“线性注意力简史:从模仿、创新到反哺”[https://www.spaces.ac.cn/archives/11033]

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐