跳至主要內容

位置编码总结

Genhiy...大约 5 分钟AITransformer位置编码

提示

本部分内容部分参考自:知乎:十分钟读懂旋转编码(RoPE)open in new window

在做 self-attention 之前,会用词嵌入向量xx计算q,k,vq, k,v向量同时加入位置信息,函数公式表达如下:

qm=fq(xm,m)kn=fk(xn,n)vn=fv(xn,n) \boldsymbol{q}_m=f_q(\boldsymbol{x}_m,m) \boldsymbol{k}_n=f_k(\boldsymbol{x}_n,n) \boldsymbol{v}_n=f_v(\boldsymbol{x}_n,n)

而基于 transformer 的位置编码方法都是着重于构造一个合适的f(q,k,v)f(q, k,v)函数形式。

绝对位置编码

对于位置编码,常规的做法是在计算 query, key 和 value 向量之前,会计算一个位置编码向量pip_i加到词嵌入xix_i上,位置编码向量pip_i同样也是dd维向量,然后再乘以对应的变换矩阵WW

ft:t{q,k,v}(xi,i):=Wt:t{q,k,v}(xi+pi) f_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i,i):=\boldsymbol{W}_{t:t\in\{q,k,v\}}(\boldsymbol{x}_i+\boldsymbol{p}_i)

而经典的位置编码向量pip_i的计算方式是使用 Sinusoidal 函数:

pi,2t=sin(k/100002t/d)pi,2t+1=cos(k/100002t/d) \begin{aligned}&\boldsymbol{p}_{i,2t}=\sin\left(k/10000^{2t/d}\right)\\&\boldsymbol{p}_{i,2t+1}=\cos\left(k/10000^{2t/d}\right)\end{aligned}

旋转位置编码RoPE

旋转位置编码(Rotary Position Embedding,RoPE)是一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。而目前很火的 LLaMA、GLM 模型也是采用该位置编码方式。和相对位置编码相比,RoPE 具有更好的外推性,目前是大模型相对位置编码中应用最广的方式之一。

备注:什么是大模型外推性?

外推性是指大模型在训练时和预测时的输入长度不一致,导致模型的泛化能力下降的问题。例如,如果一个模型在训练时只使用了512个 token 的文本,那么在预测时如果输入超过512个 token,模型可能无法正确处理。这就限制了大模型在处理长文本或多轮对话等任务时的效果。

为了能利用上 token 之间的相对位置信息,假定 query 向量和 key 向量之间的内积操作可以被一个函数gg表示,该函数gg的输入是词嵌入向量xmx_mxnx_n和它们之间的相对位置mnm-n

fq(xm,m),fk(xn,n)=g(xm,xn,mn) \langle\boldsymbol{f}_q(\boldsymbol{x}_m,m),f_k(\boldsymbol{x}_n,n)\rangle=g(\boldsymbol{x}_m,\boldsymbol{x}_n,m-n)

接下来的目标就是找到一个等价的位置编码方式,从而使得上述关系成立。

假定现在词嵌入向量的维度是两维,这样就可以利用上2维度平面上的向量的几何性质,然后论文中提出了一个满足上述关系的ffgg的形式如下:

fq(xm,m)=(Wqxm)eimθfk(xn,n)=(Wkxn)einθg(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ] \begin{aligned}&f_{q}(\boldsymbol{x}_{m},m)=\left(\boldsymbol{W}_{q}\boldsymbol{x}_{m}\right)e^{im\theta}\\&f_{k}(\boldsymbol{x}_{n},n)=(\boldsymbol{W}_{k}\boldsymbol{x}_{n})e^{in\theta}\\&g(\boldsymbol{x}_{m},\boldsymbol{x}_{n},m-n)=\mathrm{Re}\left[(\boldsymbol{W}_{q}\boldsymbol{x}_{m})(\boldsymbol{W}_{k}\boldsymbol{x}_{n})^{*}e^{i(m-n)\theta}\right]\end{aligned}

fq(xm,m)=(cosmθsinmθ)sinmθcosmθ)(Wq(1,1)Wq(1,2)Wq(2,1)Wq(2,2))(xm(1)xm(2))=(cosmθsinmθ)sinmθcosmθ)(qm(1)qm(2)) \begin{aligned} f_{q}\left(\boldsymbol{x}_{m},m\right)& \left.=\left(\begin{array}{cc}\cos m\theta&-\sin m\theta)\\\sin m\theta&\cos m\theta\end{array}\right.\right)\left(\begin{array}{cc}W_q^{(1,1)}&W_q^{(1,2)}\\W_q^{(2,1)}&W_q^{(2,2)}\end{array}\right)\left(\begin{array}{c}x_m^{(1)}\\x_m^{(2)}\end{array}\right) \\ &\left.=\left(\begin{array}{cc}\cos m\theta&-\sin m\theta)\\\sin m\theta&\cos m\theta\end{array}\right.\right)\left(\begin{array}{c}q_m^{(1)}\\q_m^{(2)}\end{array}\right) \end{aligned}

g(xm,xn,mn)=(qm(1)qm(2))(cos((mn)θ)sin((mn)θ)sin((mn)θ)cos((mn)θ))(kn(1)kn(2)) \left.g(\boldsymbol{x}_m,\boldsymbol{x}_n,m-n)=\left(\begin{array}{cc}\boldsymbol{q}_m^{(1)}&\boldsymbol{q}_m^{(2)}\end{array}\right.\right)\left(\begin{array}{cc}\cos((m-n)\theta)&-\sin((m-n)\theta)\\\sin((m-n)\theta)&\cos((m-n)\theta)\end{array}\right)\left(\begin{array}{c}k_n^{(1)}\\k_n^{(2)}\end{array}\right)

扩展到多维

RΘ,md=(cosmθ0sinmθ00000sinmθ0cosmθ0000000cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθd/21sinmθd/210000sinmθd/21cosmθd/21)WmΘ={θi=100002(i1)/d,i[1,2,,d/2]} \begin{gathered}\boldsymbol{R}_{\Theta,m}^{d}=\underbrace{\begin{pmatrix}\cos m\theta_0&-\sin m\theta_0&0&0&\cdots&0&0\\\sin m\theta_0&\cos m\theta_0&0&0&\cdots&0&0\\0&0&\cos m\theta_1&-\sin m\theta_1&\cdots&0&0\\0&0&\sin m\theta_1&\cos m\theta_1&\cdots&0&0\\\vdots&\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&0&\cdots&\cos m\theta_{d/2-1}&-\sin m\theta_{d/2-1}\\0&0&0&0&\cdots&\sin m\theta_{d/2-1}&\cos m\theta_{d/2-1}\end{pmatrix}}_{W_m}\\\Theta=\left\{\theta_i=10000^{-2(i-1)/d},i\in[1,2,\ldots,d/2]\right\}\end{gathered}

由于这个矩阵具有很高的稀疏性,直接用矩阵乘法会很浪费算力,推荐通过下述方式来实现 RoPE:

RΘ,mdx=(x0x1x2x3xd2xd1)(cosmθ0cosmθ0cosmθ0cosmθ1cosmθ1cosmθd/21cosmθd/21)+(x1x0x3x2xd1xd2)(sinmθ0sinmθ0sinmθ1sinmθ1sinmθd/21sinmθd/21) \boldsymbol{R}_{\Theta,m}^d\boldsymbol{x}=\begin{pmatrix}x_0\\x_1\\x_2\\x_3\\\vdots\\x_{d-2}\\x_{d-1}\end{pmatrix}\otimes\begin{pmatrix}\cos m\theta_0\\\cos m\theta_0\\\cos m\theta_0\\\cos m\theta_1\\\cos m\theta_1\\\vdots\\\cos m\theta_{d/2-1}\\\cos m\theta_{d/2-1}\end{pmatrix}+\begin{pmatrix}-x_1\\x_0\\-x_3\\x_2\\\vdots\\-x_{d-1}\\x_{d-2}\end{pmatrix}\otimes\begin{pmatrix}\sin m\theta_0\\\sin m\theta_0\\\sin m\theta_1\\\sin m\theta_1\\\vdots\\\sin m\theta_{d/2-1}\\\sin m\theta_{d/2-1}\end{pmatrix}

其中\otimes是逐位对应相乘,即计算框架中*的运算。从这个实现也可以看到,RoPE可以视为是乘性位置编码的变体。

相对位置编码

位置编码的前世今生:从绝对到相对open in new window

大部分长度外推的工作都在优化相对位置编码,因为绝对只能够关注特定窗口信息,而相对可以关注滑动窗口信息,这更加符合外推性的直觉。

训练式相对位置编码

最基础的训练式相对位置编码方式:本质上是学习相对位置的Embedding,再在Attention计算时将位置信息融入其中。

函数式相对位置编码

函数式相对位置编码,不再需要额外的训练参数,它的本质在于通过给原始的q,kq, k向量经过一系列变换来表示其相对位置。(RoPE)

相对与绝对融合

融合其实也非常简单,在模型的前k层只使用相对位置信息,而在后几层加入绝对位置信息,方式也非常粗暴:直接学习一个绝对位置Embedding,将其作为后面几层的qq,这样就完成了融合!